@@ -1,8 +1,10 @@
-use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
+use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
+use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
 use crate::protocol::{
-    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
+    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
+    HASH_WIDTH_IN_BYTES,
 };
 #[cfg(feature = "tls")]
 use crate::transport::TlsTransport;
@@ -10,21 +12,23 @@ use crate::transport::{TcpTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
 use backoff::backoff::Backoff;
 use backoff::ExponentialBackoff;
+
 use rand::RngCore;
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
-use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
+use tokio::net::{TcpListener, TcpStream, UdpSocket};
+use tokio::sync::{broadcast, mpsc, RwLock};
 use tokio::time;
-use tracing::{debug, error, info, info_span, warn, Instrument};
+use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};

 type ServiceDigest = protocol::Digest; // SHA256 of a service name
 type Nonce = protocol::Digest; // Also called `session_key`

-const POOL_SIZE: usize = 64; // The number of cached connections
+const TCP_POOL_SIZE: usize = 64; // The number of cached connections for TCP services
+const UDP_POOL_SIZE: usize = 2; // The number of cached connections for UDP services
 const CHAN_SIZE: usize = 2048; // The capacity of various chans

 // The entrypoint of running a server
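The newly imported `listen_backoff` drives the bind retries introduced later in this patch. Its real definition lives in crate::constants, outside this diff; below is a minimal sketch of what such a helper could look like, assuming a policy that retries forever with delays capped at one second.

use backoff::ExponentialBackoff;
use std::time::Duration;

// Sketch only (assumed shape, not the actual crate::constants definition):
// a backoff policy for bind() retries that never gives up
// (max_elapsed_time: None) and waits at most 1s between attempts.
pub fn listen_backoff() -> ExponentialBackoff {
    ExponentialBackoff {
        max_elapsed_time: None,
        max_interval: Duration::from_secs(1),
        ..Default::default()
    }
}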
@@ -268,7 +272,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
     Ok(())
 }

-async fn do_data_channel_handshake<T: Transport>(
+async fn do_data_channel_handshake<T: 'static + Transport>(
     conn: T::Stream,
     control_channels: Arc<RwLock<ControlChannelMap<T>>>,
     nonce: Nonce,
@@ -276,9 +280,9 @@ async fn do_data_channel_handshake<T: Transport>(
     // Validate
     let control_channels_guard = control_channels.read().await;
     match control_channels_guard.get2(&nonce) {
-        Some(c_ch) => {
+        Some(handle) => {
             // Send the data channel to the corresponding control channel
-            c_ch.conn_pool.data_ch_tx.send(conn).await?;
+            handle.data_ch_tx.send(conn).await?;
         }
         None => {
             // TODO: Maybe print IP here
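The `get2` call above looks a control channel up by its nonce, the secondary key of the `MultiMap`. The real `crate::multi_map::MultiMap` is outside this diff; here is a hypothetical mini version of the two-key map pattern it implies, for illustration only.

use std::collections::HashMap;
use std::hash::Hash;

// Hypothetical two-key map: values stored under K1, with a K2 -> K1 index.
// Names are illustrative, not from the codebase.
struct TwoKeyMap<K1: Eq + Hash + Clone, K2: Eq + Hash, V> {
    map: HashMap<K1, V>,
    index: HashMap<K2, K1>,
}

impl<K1: Eq + Hash + Clone, K2: Eq + Hash, V> TwoKeyMap<K1, K2, V> {
    fn insert(&mut self, k1: K1, k2: K2, v: V) {
        self.index.insert(k2, k1.clone());
        self.map.insert(k1, v);
    }

    // The get2() used above: resolve the secondary key, then the primary
    fn get2(&self, k2: &K2) -> Option<&V> {
        self.index.get(k2).and_then(|k1| self.map.get(k1))
    }
}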
@@ -288,43 +292,74 @@ async fn do_data_channel_handshake<T: Transport>(
     Ok(())
 }

-// Control channel, using T as the transport layer
-struct ControlChannel<T: Transport> {
-    conn: T::Stream, // The connection of control channel
-    service: ServerServiceConfig, // A copy of the corresponding service config
-    shutdown_rx: oneshot::Receiver<bool>, // Receives the shutdown signal
-    visitor_tx: mpsc::Sender<TcpStream>, // Receives visitor connections
-}
-
-// The handle of a control channel, along with the handle of a connection pool
-// Dropping it will drop the actual control channel, because `visitor_tx`
-// and `shutdown_tx` are closed
-struct ControlChannelHandle<T: Transport> {
+pub struct ControlChannelHandle<T: Transport> {
     // Shutdown the control channel.
     // Not used for now, but can be used for hot reloading
-    _shutdown_tx: oneshot::Sender<bool>,
-    conn_pool: ConnectionPoolHandle<T>,
+    #[allow(dead_code)]
+    shutdown_tx: broadcast::Sender<bool>,
+    //data_ch_req_tx: mpsc::Sender<bool>,
+    data_ch_tx: mpsc::Sender<T::Stream>,
 }

-impl<T: 'static + Transport> ControlChannelHandle<T> {
+impl<T> ControlChannelHandle<T>
+where
+    T: 'static + Transport,
+{
     // Create a control channel handle, where the control channel handling task
     // and the connection pool task are created.
+    #[instrument(skip_all, fields(service = %service.name))]
     fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
         // Save the name string for logging
         let name = service.name.clone();

         // Create a shutdown channel. The sender is not used for now, but for future use
-        let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
+        let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
+
+        // Store data channels
+        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
+
+        // Store data channel creation requests
+        let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
+
+        match service.service_type {
+            ServiceType::Tcp => tokio::spawn(
+                run_tcp_connection_pool::<T>(
+                    service.bind_addr.clone(),
+                    data_ch_rx,
+                    data_ch_req_tx.clone(),
+                    shutdown_tx.subscribe(),
+                )
+                .instrument(Span::current()),
+            ),
+            ServiceType::Udp => tokio::spawn(
+                run_udp_connection_pool::<T>(
+                    service.bind_addr.clone(),
+                    data_ch_rx,
+                    data_ch_req_tx.clone(),
+                    shutdown_tx.subscribe(),
+                )
+                .instrument(Span::current()),
+            ),
+        };
+
+        // Cache some data channels for later use
+        let pool_size = match service.service_type {
+            ServiceType::Tcp => TCP_POOL_SIZE,
+            ServiceType::Udp => UDP_POOL_SIZE,
+        };

-        // Create and run the connection pool, where the visitors and data channels meet
-        let conn_pool = ConnectionPoolHandle::new();
+        for _i in 0..pool_size {
+            if let Err(e) = data_ch_req_tx.send(true) {
+                error!("Failed to request data channel {}", e);
+            };
+        }

         // Create the control channel
-        let ch: ControlChannel<T> = ControlChannel {
+        let ch = ControlChannel::<T> {
             conn,
             shutdown_rx,
             service,
-            visitor_tx: conn_pool.visitor_tx.clone(),
+            data_ch_req_rx,
         };

         // Run the control channel
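The hunk above swaps the oneshot shutdown channel for a broadcast channel so that both the control channel task and the per-service pool task can observe shutdown, while the handle keeps only the senders. A self-contained sketch of that handle-plus-task pattern follows; the names are illustrative, not from the codebase.

use tokio::sync::{broadcast, mpsc};

// Illustrative handle: owns the senders; dropping it closes the request
// channel, which ends the spawned task.
struct Handle {
    #[allow(dead_code)]
    shutdown_tx: broadcast::Sender<bool>,
    req_tx: mpsc::UnboundedSender<bool>,
}

impl Handle {
    fn new() -> Handle {
        let (shutdown_tx, mut shutdown_rx) = broadcast::channel::<bool>(1);
        let (req_tx, mut req_rx) = mpsc::unbounded_channel::<bool>();
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    val = req_rx.recv() => {
                        if val.is_none() {
                            // All senders dropped: the handle is gone
                            break;
                        }
                        // A real task would emit one CreateDataChannel here
                    }
                    // Err(Closed) also ends the task when the handle drops
                    _ = shutdown_rx.recv() => break,
                }
            }
        });
        Handle { shutdown_tx, req_tx }
    }
}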
@@ -335,52 +370,83 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
         });

         ControlChannelHandle {
-            _shutdown_tx,
-            conn_pool,
+            shutdown_tx,
+            data_ch_tx,
         }
     }
+
+    #[allow(dead_code)]
+    fn shutdown(self) {
+        let _ = self.shutdown_tx.send(true);
+    }
+}
+
+// Control channel, using T as the transport layer
+struct ControlChannel<T: Transport> {
+    conn: T::Stream, // The connection of control channel
+    service: ServerServiceConfig, // A copy of the corresponding service config
+    shutdown_rx: broadcast::Receiver<bool>, // Receives the shutdown signal
+    data_ch_req_rx: mpsc::UnboundedReceiver<bool>, // Receives data channel creation requests
 }

 impl<T: Transport> ControlChannel<T> {
     // Run a control channel
-    #[tracing::instrument(skip(self), fields(service = %self.service.name))]
+    #[instrument(skip(self), fields(service = %self.service.name))]
     async fn run(mut self) -> Result<()> {
-        // Where the service is exposed
-        let l = match TcpListener::bind(&self.service.bind_addr).await {
-            Ok(v) => v,
-            Err(e) => {
-                let duration = Duration::from_secs(1);
-                error!(
-                    "Failed to listen on service.bind_addr: {}. Retry in {:?}...",
-                    e, duration
-                );
-                time::sleep(duration).await;
-                TcpListener::bind(&self.service.bind_addr).await?
+        let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
+
+        // Wait for data channel requests and the shutdown signal
+        loop {
+            tokio::select! {
+                val = self.data_ch_req_rx.recv() => {
+                    match val {
+                        Some(_) => {
+                            if let Err(e) = self.conn.write_all(&cmd).await.with_context(|| "Failed to write data channel commands") {
+                                error!("{:?}", e);
+                                break;
+                            }
+                        }
+                        None => {
+                            break;
+                        }
+                    }
+                },
+                // Wait for the shutdown signal
+                _ = self.shutdown_rx.recv() => {
+                    break;
+                }
             }
-        };
+        }
+
+        info!("Control channel shutting down");

-        info!("Listening at {}", &self.service.bind_addr);
+        Ok(())
+    }
+}

-        // Each `u8` in the chan indicates a data channel creation request
-        let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
+fn tcp_listen_and_send(
+    addr: String,
+    data_ch_req_tx: mpsc::UnboundedSender<bool>,
+    mut shutdown_rx: broadcast::Receiver<bool>,
+) -> mpsc::Receiver<TcpStream> {
+    let (tx, rx) = mpsc::channel(CHAN_SIZE);

-        // The control channel is moved into the task, and sends CreateDataChannel
-        // comamnds to the client when needed
-        tokio::spawn(async move {
-            let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
-            while data_req_rx.recv().await.is_some() {
-                if self.conn.write_all(&cmd).await.is_err() {
-                    break;
-                }
+    tokio::spawn(async move {
+        let l = backoff::future::retry(listen_backoff(), || async {
+            Ok(TcpListener::bind(&addr).await?)
+        })
+        .await
+        .with_context(|| "Failed to listen for the service");
+
+        let l: TcpListener = match l {
+            Ok(v) => v,
+            Err(e) => {
+                error!("{:?}", e);
+                return;
             }
-        });
+        };

-        // Cache some data channels for later use
-        for _i in 0..POOL_SIZE {
-            if let Err(e) = data_req_tx.send(0) {
-                error!("Failed to request data channel {}", e);
-            };
-        }
+        info!("Listening at {}", &addr);

         // Retry at least every 1s
         let mut backoff = ExponentialBackoff {
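Both the old and the new control channel serialize `ControlChannelCmd::CreateDataChannel` once and reuse the buffer for every request, which works because a unit enum variant encodes to the same bytes every time. A self-contained round-trip showing that property; `Cmd` is a stand-in, not the real `ControlChannelCmd`.

use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, PartialEq)]
enum Cmd {
    CreateDataChannel,
}

fn main() {
    // A unit variant encodes to a constant byte pattern (its variant index),
    // so the serialized buffer can be built once and written repeatedly
    let bytes = bincode::serialize(&Cmd::CreateDataChannel).unwrap();
    let decoded: Cmd = bincode::deserialize(&bytes).unwrap();
    assert_eq!(decoded, Cmd::CreateDataChannel);
}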
@@ -392,7 +458,6 @@ impl<T: Transport> ControlChannel<T> {
         // Wait for visitors and the shutdown signal
         loop {
             tokio::select! {
-                // Wait for visitors
                 val = l.accept() => {
                     match val {
                         Err(e) => {
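The lines elided between this hunk and the next handle `accept()` errors; judging from the "Too many retries. Aborting..." message kept below, they follow the usual exponential-backoff dance. A standalone sketch of that pattern, written as an assumption rather than the elided code:

use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use tokio::net::TcpListener;
use tokio::time;

// Sketch only: retry accept() with exponential backoff; None from
// next_backoff() means the policy is exhausted and the loop gives up
async fn accept_with_backoff(l: &TcpListener, backoff: &mut ExponentialBackoff) {
    loop {
        match l.accept().await {
            Ok((_stream, addr)) => {
                backoff.reset();
                tracing::debug!("New visitor from {}", addr);
            }
            Err(e) => {
                if let Some(d) = backoff.next_backoff() {
                    tracing::error!("Accept failed: {}. Retry in {:?}", e, d);
                    time::sleep(d).await;
                } else {
                    tracing::error!("Too many retries. Aborting...");
                    break;
                }
            }
        }
    }
}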
@@ -406,80 +471,103 @@ impl<T: Transport> ControlChannel<T> {
                                 error!("Too many retries. Aborting...");
                                 break;
                             }
-                        },
+                        }
                         Ok((incoming, addr)) => {
                             // For every visitor, request to create a data channel
-                            if let Err(e) = data_req_tx.send(0) {
+                            if let Err(e) = data_ch_req_tx.send(true) {
                                 // An error indicates the control channel is broken
                                 // So break the loop
                                 error!("{}", e);
                                 break;
-                            };
+                            }

                             backoff.reset();

                             debug!("New visitor from {}", addr);

                             // Send the visitor to the connection pool
-                            let _ = self.visitor_tx.send(incoming).await;
+                            let _ = tx.send(incoming).await;
                         }
                     }
                 },
-                // Wait for the shutdown signal
-                _ = &mut self.shutdown_rx => {
+                _ = shutdown_rx.recv() => {
                     break;
                 }
             }
         }
-        info!("Service shuting down");
+    });

-        Ok(())
-    }
-}
-
-#[derive(Debug)]
-struct ConnectionPool<T: Transport> {
-    visitor_rx: mpsc::Receiver<TcpStream>,
-    data_ch_rx: mpsc::Receiver<T::Stream>,
-}
-
-struct ConnectionPoolHandle<T: Transport> {
-    visitor_tx: mpsc::Sender<TcpStream>,
-    data_ch_tx: mpsc::Sender<T::Stream>,
+    rx
 }

-impl<T: 'static + Transport> ConnectionPoolHandle<T> {
-    fn new() -> ConnectionPoolHandle<T> {
-        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
-        let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
-        let conn_pool: ConnectionPool<T> = ConnectionPool {
-            data_ch_rx,
-            visitor_rx,
-        };
-
-        tokio::spawn(async move { conn_pool.run().await });
-
-        ConnectionPoolHandle {
-            data_ch_tx,
-            visitor_tx,
+#[instrument(skip_all)]
+async fn run_tcp_connection_pool<T: Transport>(
+    bind_addr: String,
+    mut data_ch_rx: mpsc::Receiver<T::Stream>,
+    data_ch_req_tx: mpsc::UnboundedSender<bool>,
+    shutdown_rx: broadcast::Receiver<bool>,
+) -> Result<()> {
+    let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx, shutdown_rx);
+    while let Some(mut visitor) = visitor_rx.recv().await {
+        if let Some(mut ch) = data_ch_rx.recv().await {
+            tokio::spawn(async move {
+                let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
+                if ch.write_all(&cmd).await.is_ok() {
+                    let _ = copy_bidirectional(&mut ch, &mut visitor).await;
+                }
+            });
+        } else {
+            break;
         }
     }
+    Ok(())
 }

-impl<T: Transport> ConnectionPool<T> {
-    #[tracing::instrument]
-    async fn run(mut self) {
-        while let Some(mut visitor) = self.visitor_rx.recv().await {
-            if let Some(mut ch) = self.data_ch_rx.recv().await {
-                tokio::spawn(async move {
-                    let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
-                    if ch.write_all(&cmd).await.is_ok() {
-                        let _ = copy_bidirectional(&mut ch, &mut visitor).await;
-                    }
-                });
-            } else {
+#[instrument(skip_all)]
+async fn run_udp_connection_pool<T: Transport>(
+    bind_addr: String,
+    mut data_ch_rx: mpsc::Receiver<T::Stream>,
+    _data_ch_req_tx: mpsc::UnboundedSender<bool>,
+    mut shutdown_rx: broadcast::Receiver<bool>,
+) -> Result<()> {
+    // TODO: Load balance
+
+    let l: UdpSocket = backoff::future::retry(listen_backoff(), || async {
+        Ok(UdpSocket::bind(&bind_addr).await?)
+    })
+    .await
+    .with_context(|| "Failed to listen for the service")?;
+
+    info!("Listening at {}", &bind_addr);
+
+    let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
+
+    let mut conn = data_ch_rx
+        .recv()
+        .await
+        .ok_or(anyhow!("No available data channels"))?;
+    conn.write_all(&cmd).await?;
+
+    let mut buf = [0u8; UDP_BUFFER_SIZE];
+    loop {
+        tokio::select! {
+            // Forward inbound traffic to the client
+            val = l.recv_from(&mut buf) => {
+                let (n, from) = val?;
+                UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
+            },
+
+            // Forward outbound traffic from the client to the visitor
+            t = UdpTraffic::read(&mut conn) => {
+                let t = t?;
+                l.send_to(&t.data, t.from).await?;
+            },
+
+            _ = shutdown_rx.recv() => {
                 break;
             }
         }
     }
+
+    Ok(())
 }
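The UDP pool relies on `UdpTraffic::write_slice` and `UdpTraffic::read`, plus the `from` and `data` fields, to relay datagrams over the stream-oriented data channel. The real type lives in crate::protocol, outside this diff; the sketch below only shows a shape consistent with that usage, and its length-prefixed wire format is an assumption.

use anyhow::Result;
use std::net::SocketAddr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

// Shape inferred from the usage above; the wire format is illustrative only
pub struct UdpTraffic {
    pub from: SocketAddr,
    pub data: Vec<u8>,
}

impl UdpTraffic {
    pub async fn write_slice<W: AsyncWrite + Unpin>(
        w: &mut W,
        from: SocketAddr,
        data: &[u8],
    ) -> Result<()> {
        // Assumed framing: a small bincode header (source addr + payload
        // length), itself prefixed with a one-byte header length, followed
        // by the raw datagram payload
        let hdr = bincode::serialize(&(from, data.len() as u32))?;
        w.write_u8(hdr.len() as u8).await?;
        w.write_all(&hdr).await?;
        w.write_all(data).await?;
        Ok(())
    }

    pub async fn read<R: AsyncRead + Unpin>(r: &mut R) -> Result<UdpTraffic> {
        let hdr_len = r.read_u8().await? as usize;
        let mut hdr = vec![0u8; hdr_len];
        r.read_exact(&mut hdr).await?;
        let (from, len): (SocketAddr, u32) = bincode::deserialize(&hdr)?;
        let mut data = vec![0u8; len as usize];
        r.read_exact(&mut data).await?;
        Ok(UdpTraffic { from, data })
    }
}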