Browse Source

feat: UDP support

Yujia Qiao 4 years ago
parent
commit
443f763800
11 changed files with 574 additions and 126 deletions
  1. 3 0
      Cargo.lock
  2. 1 1
      Cargo.toml
  3. 7 0
      examples/udp/client.toml
  4. 7 0
      examples/udp/server.toml
  5. 165 11
      src/client.rs
  6. 18 0
      src/config.rs
  7. 15 0
      src/constants.rs
  8. 82 2
      src/helper.rs
  9. 1 0
      src/lib.rs
  10. 78 3
      src/protocol.rs
  11. 197 109
      src/server.rs

+ 3 - 0
Cargo.lock

@@ -88,6 +88,9 @@ name = "bytes"
 version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c4872d67bab6358e59559027aa3b9157c53d9358c51423c17554809a8858e0f8"
+dependencies = [
+ "serde",
+]
 
 [[package]]
 name = "cc"

+ 1 - 1
Cargo.toml

@@ -24,7 +24,7 @@ opt-level = "s"
 
 [dependencies]
 tokio = { version = "1", features = ["full"] }
-bytes = { version = "1"}
+bytes = { version = "1", features = ["serde"] }
 clap = { version = "3.0.0-rc.7", features = ["derive"] }
 toml = "0.5"
 serde = { version = "1.0", features = ["derive"] }

+ 7 - 0
examples/udp/client.toml

@@ -0,0 +1,7 @@
+[client]
+remote_addr = "localhost:2333"
+default_token = "123"
+
+[client.services.foo1]
+type = "udp"
+local_addr = "127.0.0.1:80"

+ 7 - 0
examples/udp/server.toml

@@ -0,0 +1,7 @@
+[server]
+bind_addr = "0.0.0.0:2333"
+default_token = "123"
+
+[server.services.foo1]
+type = "udp"
+bind_addr = "0.0.0.0:5202"

+ 165 - 11
src/client.rs

@@ -1,23 +1,28 @@
 use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
+use crate::helper::udp_connect;
 use crate::protocol::Hello::{self, *};
 use crate::protocol::{
     self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
-    DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
+    DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
 };
 use crate::transport::{TcpTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
 use backoff::ExponentialBackoff;
+use bytes::{Bytes, BytesMut};
 use std::collections::HashMap;
+use std::net::SocketAddr;
 use std::sync::Arc;
-use tokio::io::{copy_bidirectional, AsyncWriteExt};
-use tokio::net::TcpStream;
-use tokio::sync::{broadcast, oneshot};
+use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
+use tokio::net::{TcpStream, UdpSocket};
+use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
 use tokio::time::{self, Duration};
 use tracing::{debug, error, info, instrument, Instrument, Span};
 
 #[cfg(feature = "tls")]
 use crate::transport::TlsTransport;
 
+use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
+
 // The entrypoint of running a client
 pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
     let config = match &config.client {
@@ -112,7 +117,9 @@ struct RunDataChannelArgs<T: Transport> {
     connector: Arc<T>,
 }
 
-async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
+async fn do_data_channel_handshake<T: Transport>(
+    args: Arc<RunDataChannelArgs<T>>,
+) -> Result<T::Stream> {
     // Retry at least every 100ms, at most for 10 seconds
     let backoff = ExponentialBackoff {
         max_interval: Duration::from_millis(100),
@@ -135,18 +142,165 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
     let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
     conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
 
+    Ok(conn)
+}
+
+async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
+    // Do the handshake
+    let mut conn = do_data_channel_handshake(args.clone()).await?;
+
     // Forward
     match read_data_cmd(&mut conn).await? {
-        DataChannelCmd::StartForward => {
-            let mut local = TcpStream::connect(&args.local_addr)
-                .await
-                .with_context(|| "Failed to conenct to local_addr")?;
-            let _ = copy_bidirectional(&mut conn, &mut local).await;
+        DataChannelCmd::StartForwardTcp => {
+            run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
+        }
+        DataChannelCmd::StartForwardUdp => {
+            run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
         }
     }
     Ok(())
 }
 
+// Simply copying back and forth for TCP
+#[instrument(skip(conn))]
+async fn run_data_channel_for_tcp<T: Transport>(
+    mut conn: T::Stream,
+    local_addr: &str,
+) -> Result<()> {
+    debug!("New data channel starts forwarding");
+
+    let mut local = TcpStream::connect(local_addr)
+        .await
+        .with_context(|| "Failed to conenct to local_addr")?;
+    let _ = copy_bidirectional(&mut conn, &mut local).await;
+    Ok(())
+}
+
+// Things get a little tricker when it gets to UDP because it's connectionless.
+// A UdpPortMap must be maintained for recent seen incoming address, giving them
+// each a local port, which is associated with a socket. So just the sender
+// to the socket will work fine for the map's value.
+type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
+
+#[instrument(skip(conn))]
+async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
+    debug!("New data channel starts forwarding");
+
+    let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
+
+    // The channel stores UdpTraffic that needs to be sent to the server
+    let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
+
+    // FIXME: https://github.com/tokio-rs/tls/issues/40
+    // Maybe this is our concern
+    let (mut rd, mut wr) = io::split(conn);
+
+    // Keep sending items from the outbound channel to the server
+    tokio::spawn(async move {
+        while let Some(t) = outbound_rx.recv().await {
+            debug!("outbound {:?}", t);
+            if t.write(&mut wr).await.is_err() {
+                break;
+            }
+        }
+    });
+
+    loop {
+        // Read a packet from the server
+        let packet = UdpTraffic::read(&mut rd).await?;
+        let m = port_map.read().await;
+
+        if m.get(&packet.from).is_none() {
+            // This packet is from a address we don't see for a while,
+            // which is not in the UdpPortMap.
+            // So set up a mapping (and a forwarder) for it
+
+            // Drop the reader lock
+            drop(m);
+
+            // Grab the writer lock
+            // This is the only thread that will try to grab the writer lock
+            // So no need to worry about some other thread has already set up
+            // the mapping between the gap of dropping the reader lock and
+            // grabbing the writer lock
+            let mut m = port_map.write().await;
+
+            match udp_connect(local_addr).await {
+                Ok(s) => {
+                    let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
+                    m.insert(packet.from, inbound_tx);
+                    tokio::spawn(run_udp_forwarder(
+                        s,
+                        inbound_rx,
+                        outbound_tx.clone(),
+                        packet.from,
+                        port_map.clone(),
+                    ));
+                }
+                Err(e) => {
+                    error!("{:?}", e);
+                }
+            }
+        }
+
+        // Now there should be a udp forwarder that can receive the packet
+        let m = port_map.read().await;
+        if let Some(tx) = m.get(&packet.from) {
+            let _ = tx.send(packet.data).await;
+        }
+    }
+}
+
+// Run a UdpSocket for the visitor `from`
+async fn run_udp_forwarder(
+    s: UdpSocket,
+    mut inbound_rx: mpsc::Receiver<Bytes>,
+    outbount_tx: mpsc::Sender<UdpTraffic>,
+    from: SocketAddr,
+    port_map: UdpPortMap,
+) -> Result<()> {
+    let mut buf = BytesMut::new();
+    buf.resize(UDP_BUFFER_SIZE, 0);
+
+    loop {
+        tokio::select! {
+            // Receive from the server
+            data = inbound_rx.recv() => {
+                if let Some(data) = data {
+                    s.send(&data).await?;
+                } else {
+                    break;
+                }
+            },
+
+            // Receive from the service
+            val = s.recv(&mut buf) => {
+                let len = match val {
+                    Ok(v) => v,
+                    Err(_) => {break;}
+                };
+
+                let t = UdpTraffic{
+                    from,
+                    data: Bytes::copy_from_slice(&buf[..len])
+                };
+
+                outbount_tx.send(t).await?;
+            },
+
+            // No traffic for the duration of UDP_TIMEOUT, clean up the state
+            _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
+                break;
+            }
+        }
+    }
+
+    let mut port_map = port_map.write().await;
+    port_map.remove(&from);
+
+    Ok(())
+}
+
 // Control channel, using T as the transport layer
 struct ControlChannel<T: Transport> {
     digest: ServiceDigest,              // SHA256 of the service name
@@ -163,7 +317,7 @@ struct ControlChannelHandle {
 }
 
 impl<T: 'static + Transport> ControlChannel<T> {
-    #[instrument(skip(self), fields(service=%self.service.name))]
+    #[instrument(skip_all)]
     async fn run(&mut self) -> Result<()> {
         let mut conn = self
             .transport

+ 18 - 0
src/config.rs

@@ -20,14 +20,30 @@ impl Default for TransportType {
 
 #[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct ClientServiceConfig {
+    #[serde(rename = "type", default = "default_service_type")]
+    pub service_type: ServiceType,
     #[serde(skip)]
     pub name: String,
     pub local_addr: String,
     pub token: Option<String>,
 }
 
+#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
+pub enum ServiceType {
+    #[serde(rename = "tcp")]
+    Tcp,
+    #[serde(rename = "udp")]
+    Udp,
+}
+
+fn default_service_type() -> ServiceType {
+    ServiceType::Tcp
+}
+
 #[derive(Debug, Serialize, Deserialize, Clone)]
 pub struct ServerServiceConfig {
+    #[serde(rename = "type", default = "default_service_type")]
+    pub service_type: ServiceType,
     #[serde(skip)]
     pub name: String,
     pub bind_addr: String,
@@ -231,6 +247,7 @@ mod tests {
         cfg.services.insert(
             "foo1".into(),
             ServerServiceConfig {
+                service_type: ServiceType::Tcp,
                 name: "foo1".into(),
                 bind_addr: "127.0.0.1:80".into(),
                 token: None,
@@ -277,6 +294,7 @@ mod tests {
         cfg.services.insert(
             "foo1".into(),
             ClientServiceConfig {
+                service_type: ServiceType::Tcp,
                 name: "foo1".into(),
                 local_addr: "127.0.0.1:80".into(),
                 token: None,

+ 15 - 0
src/constants.rs

@@ -0,0 +1,15 @@
+use backoff::ExponentialBackoff;
+use std::time::Duration;
+
+// FIXME: Determine reasonable size
+pub const UDP_BUFFER_SIZE: usize = 2048;
+pub const UDP_SENDQ_SIZE: usize = 1024;
+pub const UDP_TIMEOUT: u64 = 60;
+
+pub fn listen_backoff() -> ExponentialBackoff {
+    ExponentialBackoff {
+        max_elapsed_time: None,
+        max_interval: Duration::from_secs(1),
+        ..Default::default()
+    }
+}

+ 82 - 2
src/helper.rs

@@ -1,8 +1,13 @@
-use std::time::Duration;
+use std::{
+    collections::hash_map::DefaultHasher,
+    hash::{Hash, Hasher},
+    net::SocketAddr,
+    time::Duration,
+};
 
 use anyhow::{Context, Result};
 use socket2::{SockRef, TcpKeepalive};
-use tokio::net::TcpStream;
+use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
 
 // Tokio hesitates to expose this option...So we have to do it on our own :(
 // The good news is that using socket2 it can be easily done, without losing portablity.
@@ -21,3 +26,78 @@ pub fn feature_not_compile(feature: &str) -> ! {
         feature
     )
 }
+
+/// Create a UDP socket and connect to `addr`
+pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
+    // FIXME: This only works for IPv4
+    let s = UdpSocket::bind("0.0.0.0:0").await?;
+    s.connect(addr).await?;
+    Ok(s)
+}
+
+#[allow(dead_code)]
+pub fn hash_socket_addr(a: &SocketAddr) -> u64 {
+    let mut hasher = DefaultHasher::new();
+    a.hash(&mut hasher);
+    hasher.finish()
+}
+
+// Wait for the stablization of https://doc.rust-lang.org/std/primitive.i64.html#method.log2
+#[allow(dead_code)]
+fn log2_floor(x: usize) -> u8 {
+    (x as f64).log2().floor() as u8
+}
+
+#[allow(dead_code)]
+pub fn floor_to_pow_of_2(x: usize) -> usize {
+    if x == 1 {
+        return 1;
+    }
+    let w = log2_floor(x);
+    1 << w
+}
+
+#[cfg(test)]
+mod test {
+    use crate::helper::{floor_to_pow_of_2, log2_floor};
+
+    #[test]
+    fn test_log2_floor() {
+        let t = [
+            (2, 1),
+            (3, 1),
+            (4, 2),
+            (8, 3),
+            (9, 3),
+            (15, 3),
+            (16, 4),
+            (1023, 9),
+            (1024, 10),
+            (2000, 10),
+            (2048, 11),
+        ];
+        for t in t {
+            assert_eq!(log2_floor(t.0), t.1);
+        }
+    }
+
+    #[test]
+    fn test_floor_to_pow_of_2() {
+        let t = [
+            (1 as usize, 1 as usize),
+            (2, 2),
+            (3, 2),
+            (4, 4),
+            (5, 4),
+            (15, 8),
+            (31, 16),
+            (33, 32),
+            (1000, 512),
+            (1500, 1024),
+            (2300, 2048),
+        ];
+        for t in t {
+            assert_eq!(floor_to_pow_of_2(t.0), t.1);
+        }
+    }
+}

+ 1 - 0
src/lib.rs

@@ -1,5 +1,6 @@
 mod cli;
 mod config;
+mod constants;
 mod helper;
 mod multi_map;
 mod protocol;

+ 78 - 3
src/protocol.rs

@@ -1,9 +1,11 @@
 pub const HASH_WIDTH_IN_BYTES: usize = 32;
 
 use anyhow::{Context, Result};
+use bytes::{Bytes, BytesMut};
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
-use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
+use std::net::SocketAddr;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 
 type ProtocolVersion = u8;
 const PROTO_V0: u8 = 0u8;
@@ -49,7 +51,80 @@ pub enum ControlChannelCmd {
 
 #[derive(Deserialize, Serialize, Debug)]
 pub enum DataChannelCmd {
-    StartForward,
+    StartForwardTcp,
+    StartForwardUdp,
+}
+
+type UdpPacketLen = u16; // `u16` should be enough for any practical UDP traffic on the Internet
+#[derive(Deserialize, Serialize, Debug)]
+struct UdpHeader {
+    from: SocketAddr,
+    len: UdpPacketLen,
+}
+
+#[derive(Debug)]
+pub struct UdpTraffic {
+    pub from: SocketAddr,
+    pub data: Bytes,
+}
+
+impl UdpTraffic {
+    pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
+        let v = bincode::serialize(&UdpHeader {
+            from: self.from,
+            len: self.data.len() as UdpPacketLen,
+        })
+        .unwrap();
+
+        writer.write_u16(v.len() as u16).await?;
+        writer.write_all(&v).await?;
+
+        writer.write_all(&self.data).await?;
+
+        Ok(())
+    }
+
+    #[allow(dead_code)]
+    pub async fn write_slice<T: AsyncWrite + Unpin>(
+        writer: &mut T,
+        from: SocketAddr,
+        data: &[u8],
+    ) -> Result<()> {
+        let v = bincode::serialize(&UdpHeader {
+            from,
+            len: data.len() as UdpPacketLen,
+        })
+        .unwrap();
+
+        writer.write_u16(v.len() as u16).await?;
+        writer.write_all(&v).await?;
+
+        writer.write_all(data).await?;
+
+        Ok(())
+    }
+
+    pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
+        let len = reader.read_u16().await? as usize;
+
+        let mut buf = Vec::new();
+        buf.resize(len, 0);
+        reader
+            .read_exact(&mut buf)
+            .await
+            .with_context(|| "Failed to read udp header")?;
+        let header: UdpHeader =
+            bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
+
+        let mut data = BytesMut::new();
+        data.resize(header.len as usize, 0);
+        reader.read_exact(&mut data).await?;
+
+        Ok(UdpTraffic {
+            from: header.from,
+            data: data.freeze(),
+        })
+    }
 }
 
 pub fn digest(data: &[u8]) -> Digest {
@@ -74,7 +149,7 @@ impl PacketLength {
             .unwrap() as usize;
         let c_cmd =
             bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
-        let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForward).unwrap() as usize;
+        let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForwardTcp).unwrap() as usize;
         let ack = Ack::Ok;
         let ack = bincode::serialized_size(&ack).unwrap() as usize;
 

+ 197 - 109
src/server.rs

@@ -1,8 +1,10 @@
-use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
+use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
+use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
 use crate::protocol::{
-    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
+    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
+    HASH_WIDTH_IN_BYTES,
 };
 #[cfg(feature = "tls")]
 use crate::transport::TlsTransport;
@@ -10,21 +12,23 @@ use crate::transport::{TcpTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
 use backoff::backoff::Backoff;
 use backoff::ExponentialBackoff;
+
 use rand::RngCore;
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
-use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
+use tokio::net::{TcpListener, TcpStream, UdpSocket};
+use tokio::sync::{broadcast, mpsc, RwLock};
 use tokio::time;
-use tracing::{debug, error, info, info_span, warn, Instrument};
+use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
 
 type ServiceDigest = protocol::Digest; // SHA256 of a service name
 type Nonce = protocol::Digest; // Also called `session_key`
 
-const POOL_SIZE: usize = 64; // The number of cached connections
+const TCP_POOL_SIZE: usize = 64; // The number of cached connections for TCP servies
+const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP services
 const CHAN_SIZE: usize = 2048; // The capacity of various chans
 
 // The entrypoint of running a server
@@ -268,7 +272,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
     Ok(())
 }
 
-async fn do_data_channel_handshake<T: Transport>(
+async fn do_data_channel_handshake<T: 'static + Transport>(
     conn: T::Stream,
     control_channels: Arc<RwLock<ControlChannelMap<T>>>,
     nonce: Nonce,
@@ -276,9 +280,9 @@ async fn do_data_channel_handshake<T: Transport>(
     // Validate
     let control_channels_guard = control_channels.read().await;
     match control_channels_guard.get2(&nonce) {
-        Some(c_ch) => {
+        Some(handle) => {
             // Send the data channel to the corresponding control channel
-            c_ch.conn_pool.data_ch_tx.send(conn).await?;
+            handle.data_ch_tx.send(conn).await?;
         }
         None => {
             // TODO: Maybe print IP here
@@ -288,43 +292,74 @@ async fn do_data_channel_handshake<T: Transport>(
     Ok(())
 }
 
-// Control channel, using T as the transport layer
-struct ControlChannel<T: Transport> {
-    conn: T::Stream,                      // The connection of control channel
-    service: ServerServiceConfig,         // A copy of the corresponding service config
-    shutdown_rx: oneshot::Receiver<bool>, // Receives the shutdown signal
-    visitor_tx: mpsc::Sender<TcpStream>,  // Receives visitor connections
-}
-
-// The handle of a control channel, along with the handle of a connection pool
-// Dropping it will drop the actual control channel, because `visitor_tx`
-// and `shutdown_tx` are closed
-struct ControlChannelHandle<T: Transport> {
+pub struct ControlChannelHandle<T: Transport> {
     // Shutdown the control channel.
     // Not used for now, but can be used for hot reloading
-    _shutdown_tx: oneshot::Sender<bool>,
-    conn_pool: ConnectionPoolHandle<T>,
+    #[allow(dead_code)]
+    shutdown_tx: broadcast::Sender<bool>,
+    //data_ch_req_tx: mpsc::Sender<bool>,
+    data_ch_tx: mpsc::Sender<T::Stream>,
 }
 
-impl<T: 'static + Transport> ControlChannelHandle<T> {
+impl<T> ControlChannelHandle<T>
+where
+    T: 'static + Transport,
+{
     // Create a control channel handle, where the control channel handling task
     // and the connection pool task are created.
+    #[instrument(skip_all, fields(service = %service.name))]
     fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
         // Save the name string for logging
         let name = service.name.clone();
 
         // Create a shutdown channel. The sender is not used for now, but for future use
-        let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
+        let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
+
+        // Store data channels
+        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
+
+        // Store data channel creation requests
+        let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
+
+        match service.service_type {
+            ServiceType::Tcp => tokio::spawn(
+                run_tcp_connection_pool::<T>(
+                    service.bind_addr.clone(),
+                    data_ch_rx,
+                    data_ch_req_tx.clone(),
+                    shutdown_tx.subscribe(),
+                )
+                .instrument(Span::current()),
+            ),
+            ServiceType::Udp => tokio::spawn(
+                run_udp_connection_pool::<T>(
+                    service.bind_addr.clone(),
+                    data_ch_rx,
+                    data_ch_req_tx.clone(),
+                    shutdown_tx.subscribe(),
+                )
+                .instrument(Span::current()),
+            ),
+        };
+
+        // Cache some data channels for later use
+        let pool_size = match service.service_type {
+            ServiceType::Tcp => TCP_POOL_SIZE,
+            ServiceType::Udp => UDP_POOL_SIZE,
+        };
 
-        // Create and run the connection pool, where the visitors and data channels meet
-        let conn_pool = ConnectionPoolHandle::new();
+        for _i in 0..pool_size {
+            if let Err(e) = data_ch_req_tx.send(true) {
+                error!("Failed to request data channel {}", e);
+            };
+        }
 
         // Create the control channel
-        let ch: ControlChannel<T> = ControlChannel {
+        let ch = ControlChannel::<T> {
             conn,
             shutdown_rx,
             service,
-            visitor_tx: conn_pool.visitor_tx.clone(),
+            data_ch_req_rx,
         };
 
         // Run the control channel
@@ -335,52 +370,83 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
         });
 
         ControlChannelHandle {
-            _shutdown_tx,
-            conn_pool,
+            shutdown_tx,
+            data_ch_tx,
         }
     }
+
+    #[allow(dead_code)]
+    fn shutdown(self) {
+        let _ = self.shutdown_tx.send(true);
+    }
+}
+
+// Control channel, using T as the transport layer. P is TcpStream or UdpTraffic
+struct ControlChannel<T: Transport> {
+    conn: T::Stream,                               // The connection of control channel
+    service: ServerServiceConfig,                  // A copy of the corresponding service config
+    shutdown_rx: broadcast::Receiver<bool>,        // Receives the shutdown signal
+    data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives visitor connections
 }
 
 impl<T: Transport> ControlChannel<T> {
     // Run a control channel
-    #[tracing::instrument(skip(self), fields(service = %self.service.name))]
+    #[instrument(skip(self), fields(service = %self.service.name))]
     async fn run(mut self) -> Result<()> {
-        // Where the service is exposed
-        let l = match TcpListener::bind(&self.service.bind_addr).await {
-            Ok(v) => v,
-            Err(e) => {
-                let duration = Duration::from_secs(1);
-                error!(
-                    "Failed to listen on service.bind_addr: {}. Retry in {:?}...",
-                    e, duration
-                );
-                time::sleep(duration).await;
-                TcpListener::bind(&self.service.bind_addr).await?
+        let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
+
+        // Wait for data channel requests and the shutdown signal
+        loop {
+            tokio::select! {
+                val = self.data_ch_req_rx.recv() => {
+                    match val {
+                        Some(_) => {
+                            if let Err(e) = self.conn.write_all(&cmd).await.with_context(||"Failed to write data cmds") {
+                                error!("{:?}", e);
+                                break;
+                            }
+                        }
+                        None => {
+                            break;
+                        }
+                    }
+                },
+                // Wait for the shutdown signal
+                _ = self.shutdown_rx.recv() => {
+                    break;
+                }
             }
-        };
+        }
+
+        info!("Control channel shuting down");
 
-        info!("Listening at {}", &self.service.bind_addr);
+        Ok(())
+    }
+}
 
-        // Each `u8` in the chan indicates a data channel creation request
-        let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
+fn tcp_listen_and_send(
+    addr: String,
+    data_ch_req_tx: mpsc::UnboundedSender<bool>,
+    mut shutdown_rx: broadcast::Receiver<bool>,
+) -> mpsc::Receiver<TcpStream> {
+    let (tx, rx) = mpsc::channel(CHAN_SIZE);
 
-        // The control channel is moved into the task, and sends CreateDataChannel
-        // comamnds to the client when needed
-        tokio::spawn(async move {
-            let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
-            while data_req_rx.recv().await.is_some() {
-                if self.conn.write_all(&cmd).await.is_err() {
-                    break;
-                }
+    tokio::spawn(async move {
+        let l = backoff::future::retry(listen_backoff(), || async {
+            Ok(TcpListener::bind(&addr).await?)
+        })
+        .await
+        .with_context(|| "Failed to listen for the service");
+
+        let l: TcpListener = match l {
+            Ok(v) => v,
+            Err(e) => {
+                error!("{:?}", e);
+                return;
             }
-        });
+        };
 
-        // Cache some data channels for later use
-        for _i in 0..POOL_SIZE {
-            if let Err(e) = data_req_tx.send(0) {
-                error!("Failed to request data channel {}", e);
-            };
-        }
+        info!("Listening at {}", &addr);
 
         // Retry at least every 1s
         let mut backoff = ExponentialBackoff {
@@ -392,7 +458,6 @@ impl<T: Transport> ControlChannel<T> {
         // Wait for visitors and the shutdown signal
         loop {
             tokio::select! {
-                // Wait for visitors
                 val = l.accept() => {
                     match val {
                         Err(e) => {
@@ -406,80 +471,103 @@ impl<T: Transport> ControlChannel<T> {
                                 error!("Too many retries. Aborting...");
                                 break;
                             }
-                        },
+                        }
                         Ok((incoming, addr)) => {
                             // For every visitor, request to create a data channel
-                            if let Err(e) = data_req_tx.send(0) {
+                            if let Err(e) = data_ch_req_tx.send(true) {
                                 // An error indicates the control channel is broken
                                 // So break the loop
                                 error!("{}", e);
                                 break;
-                            };
+                            }
 
                             backoff.reset();
 
                             debug!("New visitor from {}", addr);
 
                             // Send the visitor to the connection pool
-                            let _ = self.visitor_tx.send(incoming).await;
+                            let _ = tx.send(incoming).await;
                         }
                     }
                 },
-                // Wait for the shutdown signal
-                _ = &mut self.shutdown_rx => {
+                _ = shutdown_rx.recv() => {
                     break;
                 }
             }
         }
-        info!("Service shuting down");
+    });
 
-        Ok(())
-    }
-}
-
-#[derive(Debug)]
-struct ConnectionPool<T: Transport> {
-    visitor_rx: mpsc::Receiver<TcpStream>,
-    data_ch_rx: mpsc::Receiver<T::Stream>,
-}
-
-struct ConnectionPoolHandle<T: Transport> {
-    visitor_tx: mpsc::Sender<TcpStream>,
-    data_ch_tx: mpsc::Sender<T::Stream>,
+    rx
 }
 
-impl<T: 'static + Transport> ConnectionPoolHandle<T> {
-    fn new() -> ConnectionPoolHandle<T> {
-        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
-        let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
-        let conn_pool: ConnectionPool<T> = ConnectionPool {
-            data_ch_rx,
-            visitor_rx,
-        };
-
-        tokio::spawn(async move { conn_pool.run().await });
-
-        ConnectionPoolHandle {
-            data_ch_tx,
-            visitor_tx,
+#[instrument(skip_all)]
+async fn run_tcp_connection_pool<T: Transport>(
+    bind_addr: String,
+    mut data_ch_rx: mpsc::Receiver<T::Stream>,
+    data_ch_req_tx: mpsc::UnboundedSender<bool>,
+    shutdown_rx: broadcast::Receiver<bool>,
+) -> Result<()> {
+    let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
+    while let Some(mut visitor) = visitor_rx.recv().await {
+        if let Some(mut ch) = data_ch_rx.recv().await {
+            tokio::spawn(async move {
+                let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
+                if ch.write_all(&cmd).await.is_ok() {
+                    let _ = copy_bidirectional(&mut ch, &mut visitor).await;
+                }
+            });
+        } else {
+            break;
         }
     }
+    Ok(())
 }
 
-impl<T: Transport> ConnectionPool<T> {
-    #[tracing::instrument]
-    async fn run(mut self) {
-        while let Some(mut visitor) = self.visitor_rx.recv().await {
-            if let Some(mut ch) = self.data_ch_rx.recv().await {
-                tokio::spawn(async move {
-                    let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
-                    if ch.write_all(&cmd).await.is_ok() {
-                        let _ = copy_bidirectional(&mut ch, &mut visitor).await;
-                    }
-                });
-            } else {
+#[instrument(skip_all)]
+async fn run_udp_connection_pool<T: Transport>(
+    bind_addr: String,
+    mut data_ch_rx: mpsc::Receiver<T::Stream>,
+    _data_ch_req_tx: mpsc::UnboundedSender<bool>,
+    mut shutdown_rx: broadcast::Receiver<bool>,
+) -> Result<()> {
+    // TODO: Load balance
+
+    let l: UdpSocket = backoff::future::retry(listen_backoff(), || async {
+        Ok(UdpSocket::bind(&bind_addr).await?)
+    })
+    .await
+    .with_context(|| "Failed to listen for the service")?;
+
+    info!("Listening at {}", &bind_addr);
+
+    let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
+
+    let mut conn = data_ch_rx
+        .recv()
+        .await
+        .ok_or(anyhow!("No available data channels"))?;
+    conn.write_all(&cmd).await?;
+
+    let mut buf = [0u8; UDP_BUFFER_SIZE];
+    loop {
+        tokio::select! {
+            // Forward inbound traffic to the client
+            val = l.recv_from(&mut buf) => {
+                let (n, from) = val?;
+                UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
+            },
+
+            // Forward outbound traffic from the client to the visitor
+            t = UdpTraffic::read(&mut conn) => {
+                let t = t?;
+                l.send_to(&t.data, t.from).await?;
+            },
+
+            _ = shutdown_rx.recv() => {
                 break;
             }
         }
     }
+
+    Ok(())
 }