|
|
@@ -1,29 +1,25 @@
|
|
|
-use std::collections::HashMap;
|
|
|
-use std::net::SocketAddr;
|
|
|
-use std::sync::Arc;
|
|
|
-use std::time::Duration;
|
|
|
-
|
|
|
-use crate::config::{Config, ServerConfig, ServerServiceConfig};
|
|
|
-use crate::helper::set_tcp_keepalive;
|
|
|
+use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
|
|
|
use crate::multi_map::MultiMap;
|
|
|
use crate::protocol::{
|
|
|
- self, read_hello, Hello, Hello::ControlChannelHello, Hello::DataChannelHello,
|
|
|
+ self, Ack, ControlChannelCmd, DataChannelCmd, Hello, Hello::ControlChannelHello,
|
|
|
+ Hello::DataChannelHello, HASH_WIDTH_IN_BYTES,
|
|
|
};
|
|
|
-use crate::protocol::{read_auth, Ack, ControlChannelCmd, DataChannelCmd, HASH_WIDTH_IN_BYTES};
|
|
|
+use crate::protocol::{read_auth, read_hello};
|
|
|
+use crate::transport::{TcpTransport, TlsTransport, Transport};
|
|
|
use anyhow::{anyhow, bail, Context, Result};
|
|
|
+use backoff::{backoff::Backoff, ExponentialBackoff};
|
|
|
use rand::RngCore;
|
|
|
-use tokio::io::{self, AsyncWriteExt};
|
|
|
+use std::collections::HashMap;
|
|
|
+use std::net::SocketAddr;
|
|
|
+use std::sync::Arc;
|
|
|
+use std::time::Duration;
|
|
|
+use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
|
|
|
+use tokio::net::{TcpListener, TcpStream};
|
|
|
use tokio::sync::mpsc;
|
|
|
use tokio::sync::{oneshot, RwLock};
|
|
|
use tokio::time;
|
|
|
-use tokio::{
|
|
|
- self,
|
|
|
- net::{self, TcpListener, TcpStream},
|
|
|
-};
|
|
|
use tracing::{debug, error, info, info_span, warn, Instrument};
|
|
|
|
|
|
-use backoff::{backoff::Backoff, ExponentialBackoff};
|
|
|
-
|
|
|
type ServiceDigest = protocol::Digest;
|
|
|
type Nonce = protocol::Digest;
|
|
|
|
|
|
@@ -31,43 +27,57 @@ const POOL_SIZE: usize = 64;
|
|
|
const CHAN_SIZE: usize = 2048;
|
|
|
|
|
|
pub async fn run_server(config: &Config) -> Result<()> {
|
|
|
- let mut server = Server::from(config)?;
|
|
|
-
|
|
|
- server.run().await
|
|
|
+ let config = match &config.server {
|
|
|
+ Some(config) => config,
|
|
|
+ None => {
|
|
|
+ return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
|
|
|
+ }
|
|
|
+ };
|
|
|
+ match config.transport.transport_type {
|
|
|
+ TransportType::Tcp => {
|
|
|
+ let mut server = Server::<TcpTransport>::from(config).await?;
|
|
|
+ server.run().await?;
|
|
|
+ }
|
|
|
+ TransportType::Tls => {
|
|
|
+ let mut server = Server::<TlsTransport>::from(config).await?;
|
|
|
+ server.run().await?;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ Ok(())
|
|
|
}
|
|
|
|
|
|
-type ControlChannelMap = MultiMap<ServiceDigest, Nonce, ControlChannelHandle>;
|
|
|
-struct Server<'a> {
|
|
|
+type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;
|
|
|
+struct Server<'a, T: Transport> {
|
|
|
config: &'a ServerConfig,
|
|
|
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
|
|
- control_channels: Arc<RwLock<ControlChannelMap>>,
|
|
|
+ control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
|
|
+ transport: Arc<T>,
|
|
|
}
|
|
|
|
|
|
-impl<'a> Server<'a> {
|
|
|
- pub fn from(config: &'a Config) -> Result<Server> {
|
|
|
- match &config.server {
|
|
|
- Some(config) => Ok(Server {
|
|
|
- config,
|
|
|
- services: Arc::new(RwLock::new(Server::generate_service_hashmap(config))),
|
|
|
- control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
|
|
|
- }),
|
|
|
- None =>
|
|
|
- Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
|
|
|
- }
|
|
|
+fn generate_service_hashmap(
|
|
|
+ server_config: &ServerConfig,
|
|
|
+) -> HashMap<ServiceDigest, ServerServiceConfig> {
|
|
|
+ let mut ret = HashMap::new();
|
|
|
+ for u in &server_config.services {
|
|
|
+ ret.insert(protocol::digest(u.0.as_bytes()), (*u.1).clone());
|
|
|
}
|
|
|
+ ret
|
|
|
+}
|
|
|
|
|
|
- fn generate_service_hashmap(
|
|
|
- server_config: &ServerConfig,
|
|
|
- ) -> HashMap<ServiceDigest, ServerServiceConfig> {
|
|
|
- let mut ret = HashMap::new();
|
|
|
- for u in &server_config.services {
|
|
|
- ret.insert(protocol::digest(u.0.as_bytes()), (*u.1).clone());
|
|
|
- }
|
|
|
- ret
|
|
|
+impl<'a, T: 'static + Transport> Server<'a, T> {
|
|
|
+ pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
|
|
|
+ Ok(Server {
|
|
|
+ config,
|
|
|
+ services: Arc::new(RwLock::new(generate_service_hashmap(config))),
|
|
|
+ control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
|
|
|
+ transport: Arc::new(*(T::new(&config.transport).await?)),
|
|
|
+ })
|
|
|
}
|
|
|
|
|
|
pub async fn run(&mut self) -> Result<()> {
|
|
|
- let l = net::TcpListener::bind(&self.config.bind_addr)
|
|
|
+ let l = self
|
|
|
+ .transport
|
|
|
+ .bind(&self.config.bind_addr)
|
|
|
.await
|
|
|
.with_context(|| "Failed to listen at `server.bind_addr`")?;
|
|
|
info!("Listening at {}", self.config.bind_addr);
|
|
|
@@ -82,22 +92,25 @@ impl<'a> Server<'a> {
|
|
|
// Listen for incoming control or data channels
|
|
|
loop {
|
|
|
tokio::select! {
|
|
|
- ret = l.accept() => {
|
|
|
+ ret = self.transport.accept(&l) => {
|
|
|
match ret {
|
|
|
Err(err) => {
|
|
|
- // Possibly a EMFILE. So sleep for a while and retry
|
|
|
- if let Some(d) = backoff.next_backoff() {
|
|
|
- error!("Failed to accept: {}. Retry in {:?}...", err, d);
|
|
|
- time::sleep(d).await;
|
|
|
- } else {
|
|
|
- // This branch will never be executed according to the current retry policy
|
|
|
- error!("Too many retries. Aborting...");
|
|
|
- break;
|
|
|
+ if let Some(err) = err.downcast_ref::<io::Error>() {
|
|
|
+ // Possibly a EMFILE. So sleep for a while and retry
|
|
|
+ if let Some(d) = backoff.next_backoff() {
|
|
|
+ error!("Failed to accept: {}. Retry in {:?}...", err, d);
|
|
|
+ time::sleep(d).await;
|
|
|
+ } else {
|
|
|
+ // This branch will never be executed according to the current retry policy
|
|
|
+ error!("Too many retries. Aborting...");
|
|
|
+ break;
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
Ok((conn, addr)) => {
|
|
|
backoff.reset();
|
|
|
debug!("Incomming connection from {}", addr);
|
|
|
+
|
|
|
let services = self.services.clone();
|
|
|
let control_channels = self.control_channels.clone();
|
|
|
tokio::spawn(async move {
|
|
|
@@ -119,11 +132,11 @@ impl<'a> Server<'a> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-async fn handle_connection(
|
|
|
- mut conn: TcpStream,
|
|
|
+async fn handle_connection<T: 'static + Transport>(
|
|
|
+ mut conn: T::Stream,
|
|
|
addr: SocketAddr,
|
|
|
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
|
|
- control_channels: Arc<RwLock<ControlChannelMap>>,
|
|
|
+ control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
|
|
) -> Result<()> {
|
|
|
// Read hello
|
|
|
let hello = read_hello(&mut conn).await?;
|
|
|
@@ -139,11 +152,11 @@ async fn handle_connection(
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
-async fn do_control_channel_handshake(
|
|
|
- mut conn: TcpStream,
|
|
|
+async fn do_control_channel_handshake<T: 'static + Transport>(
|
|
|
+ mut conn: T::Stream,
|
|
|
addr: SocketAddr,
|
|
|
services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
|
|
|
- control_channels: Arc<RwLock<ControlChannelMap>>,
|
|
|
+ control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
|
|
service_digest: ServiceDigest,
|
|
|
) -> Result<()> {
|
|
|
info!("New control channel incomming from {}", addr);
|
|
|
@@ -219,19 +232,15 @@ async fn do_control_channel_handshake(
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
-async fn do_data_channel_handshake(
|
|
|
- conn: TcpStream,
|
|
|
- control_channels: Arc<RwLock<ControlChannelMap>>,
|
|
|
+async fn do_data_channel_handshake<T: Transport>(
|
|
|
+ conn: T::Stream,
|
|
|
+ control_channels: Arc<RwLock<ControlChannelMap<T>>>,
|
|
|
nonce: Nonce,
|
|
|
) -> Result<()> {
|
|
|
// Validate
|
|
|
let control_channels_guard = control_channels.read().await;
|
|
|
match control_channels_guard.get2(&nonce) {
|
|
|
Some(c_ch) => {
|
|
|
- if let Err(e) = set_tcp_keepalive(&conn) {
|
|
|
- error!("The connection may be unstable! {:?}", e);
|
|
|
- }
|
|
|
-
|
|
|
// Send the data channel to the corresponding control channel
|
|
|
c_ch.conn_pool.data_ch_tx.send(conn).await?;
|
|
|
}
|
|
|
@@ -242,24 +251,24 @@ async fn do_data_channel_handshake(
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
-struct ControlChannel {
|
|
|
- conn: TcpStream,
|
|
|
+struct ControlChannel<T: Transport> {
|
|
|
+ conn: T::Stream,
|
|
|
service: ServerServiceConfig,
|
|
|
shutdown_rx: oneshot::Receiver<bool>,
|
|
|
visitor_tx: mpsc::Sender<TcpStream>,
|
|
|
}
|
|
|
|
|
|
-struct ControlChannelHandle {
|
|
|
+struct ControlChannelHandle<T: Transport> {
|
|
|
shutdown_tx: oneshot::Sender<bool>,
|
|
|
- conn_pool: ConnectionPoolHandle,
|
|
|
+ conn_pool: ConnectionPoolHandle<T>,
|
|
|
}
|
|
|
|
|
|
-impl ControlChannelHandle {
|
|
|
- fn new(conn: TcpStream, service: ServerServiceConfig) -> ControlChannelHandle {
|
|
|
+impl<T: 'static + Transport> ControlChannelHandle<T> {
|
|
|
+ fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
|
|
|
let (shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
|
|
|
let name = service.name.clone();
|
|
|
let conn_pool = ConnectionPoolHandle::new();
|
|
|
- let actor = ControlChannel {
|
|
|
+ let actor: ControlChannel<T> = ControlChannel {
|
|
|
conn,
|
|
|
shutdown_rx,
|
|
|
service,
|
|
|
@@ -279,13 +288,9 @@ impl ControlChannelHandle {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-impl ControlChannel {
|
|
|
+impl<T: Transport> ControlChannel<T> {
|
|
|
#[tracing::instrument(skip(self), fields(service = %self.service.name))]
|
|
|
async fn run(mut self) -> Result<()> {
|
|
|
- if let Err(e) = set_tcp_keepalive(&self.conn) {
|
|
|
- error!("The connection may be unstable! {:?}", e);
|
|
|
- }
|
|
|
-
|
|
|
let l = match TcpListener::bind(&self.service.bind_addr).await {
|
|
|
Ok(v) => v,
|
|
|
Err(e) => {
|
|
|
@@ -360,21 +365,21 @@ impl ControlChannel {
|
|
|
}
|
|
|
|
|
|
#[derive(Debug)]
|
|
|
-struct ConnectionPool {
|
|
|
+struct ConnectionPool<T: Transport> {
|
|
|
visitor_rx: mpsc::Receiver<TcpStream>,
|
|
|
- data_ch_rx: mpsc::Receiver<TcpStream>,
|
|
|
+ data_ch_rx: mpsc::Receiver<T::Stream>,
|
|
|
}
|
|
|
|
|
|
-struct ConnectionPoolHandle {
|
|
|
+struct ConnectionPoolHandle<T: Transport> {
|
|
|
visitor_tx: mpsc::Sender<TcpStream>,
|
|
|
- data_ch_tx: mpsc::Sender<TcpStream>,
|
|
|
+ data_ch_tx: mpsc::Sender<T::Stream>,
|
|
|
}
|
|
|
|
|
|
-impl ConnectionPoolHandle {
|
|
|
- fn new() -> ConnectionPoolHandle {
|
|
|
+impl<T: 'static + Transport> ConnectionPoolHandle<T> {
|
|
|
+ fn new() -> ConnectionPoolHandle<T> {
|
|
|
let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
|
|
|
let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
|
|
|
- let conn_pool = ConnectionPool {
|
|
|
+ let conn_pool: ConnectionPool<T> = ConnectionPool {
|
|
|
data_ch_rx,
|
|
|
visitor_rx,
|
|
|
};
|
|
|
@@ -388,7 +393,7 @@ impl ConnectionPoolHandle {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-impl ConnectionPool {
|
|
|
+impl<T: Transport> ConnectionPool<T> {
|
|
|
#[tracing::instrument]
|
|
|
async fn run(mut self) {
|
|
|
loop {
|
|
|
@@ -397,7 +402,7 @@ impl ConnectionPool {
|
|
|
tokio::spawn(async move {
|
|
|
let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
|
|
|
if ch.write_all(&cmd).await.is_ok() {
|
|
|
- let _ = io::copy_bidirectional(&mut ch, &mut visitor).await;
|
|
|
+ let _ = copy_bidirectional(&mut ch, &mut visitor).await;
|
|
|
}
|
|
|
});
|
|
|
} else {
|