server.rs

use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
use crate::multi_map::MultiMap;
use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
use crate::protocol::{
    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
};
use crate::transport::{TcpTransport, TlsTransport, Transport};
use anyhow::{anyhow, bail, Context, Result};
use backoff::backoff::Backoff;
use backoff::ExponentialBackoff;
use rand::RngCore;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, oneshot, RwLock};
use tokio::time;
use tracing::{debug, error, info, info_span, warn, Instrument};

type ServiceDigest = protocol::Digest;
type Nonce = protocol::Digest;

const POOL_SIZE: usize = 64;
const CHAN_SIZE: usize = 2048;
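
/// Entry point of the server side: checks that the `[server]` block exists
/// and dispatches to a `Server` instance specialized for the configured transport.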
pub async fn run_server(config: &Config) -> Result<()> {
    let config = match &config.server {
        Some(config) => config,
        None => {
            return Err(anyhow!("Trying to run as a server, but the configuration is missing. Please add the `[server]` block"))
        }
    };

    match config.transport.transport_type {
        TransportType::Tcp => {
            let mut server = Server::<TcpTransport>::from(config).await?;
            server.run().await?;
        }
        TransportType::Tls => {
            let mut server = Server::<TlsTransport>::from(config).await?;
            server.run().await?;
        }
    }

    Ok(())
}
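
/// Control channel handles, indexed both by the service digest and by the session key (nonce),
/// so that a data channel carrying a nonce can find its control channel.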
type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;

struct Server<'a, T: Transport> {
    config: &'a ServerConfig,
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
    transport: Arc<T>,
}
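
/// Hashes every configured service name so that handshake lookups can be done
/// with the digest sent by the client.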
fn generate_service_hashmap(
    server_config: &ServerConfig,
) -> HashMap<ServiceDigest, ServerServiceConfig> {
    let mut ret = HashMap::new();
    for (name, service_config) in &server_config.services {
        ret.insert(protocol::digest(name.as_bytes()), service_config.clone());
    }
    ret
}
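
// `from` only builds the in-memory state; the listener is created in `run`.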
impl<'a, T: 'static + Transport> Server<'a, T> {
    pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
        Ok(Server {
            config,
            services: Arc::new(RwLock::new(generate_service_hashmap(config))),
            control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
            transport: Arc::new(*(T::new(&config.transport).await?)),
        })
    }

    pub async fn run(&mut self) -> Result<()> {
        let l = self
            .transport
            .bind(&self.config.bind_addr)
            .await
            .with_context(|| "Failed to listen at `server.bind_addr`")?;
        info!("Listening at {}", self.config.bind_addr);

        // Retry at least every 100ms
        let mut backoff = ExponentialBackoff {
            max_interval: Duration::from_millis(100),
            max_elapsed_time: None,
            ..Default::default()
        };

        // Listen for incoming control or data channels
        loop {
            tokio::select! {
                ret = self.transport.accept(&l) => {
                    match ret {
                        Err(err) => {
                            if let Some(err) = err.downcast_ref::<io::Error>() {
                                // Possibly an EMFILE. So sleep for a while and retry
                                if let Some(d) = backoff.next_backoff() {
                                    error!("Failed to accept: {}. Retry in {:?}...", err, d);
                                    time::sleep(d).await;
                                } else {
                                    // This branch will never be executed according to the current retry policy
                                    error!("Too many retries. Aborting...");
                                    break;
                                }
                            }
                        }
                        Ok((conn, addr)) => {
                            backoff.reset();
                            debug!("Incoming connection from {}", addr);

                            let services = self.services.clone();
                            let control_channels = self.control_channels.clone();
                            tokio::spawn(
                                async move {
                                    if let Err(err) = handle_connection(conn, addr, services, control_channels)
                                        .await
                                        .with_context(|| "Failed to handle a connection to `server.bind_addr`")
                                    {
                                        error!("{:?}", err);
                                    }
                                }
                                .instrument(info_span!("handle_connection", %addr)),
                            );
                        }
                    }
                },
                _ = tokio::signal::ctrl_c() => {
                    info!("Shutting down gracefully...");
                    break;
                }
            }
        }

        Ok(())
    }
}
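
/// Handles an incoming connection on `server.bind_addr`: reads the hello and
/// branches into the control channel or data channel handshake.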
async fn handle_connection<T: 'static + Transport>(
    mut conn: T::Stream,
    addr: SocketAddr,
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
) -> Result<()> {
    // Read hello
    let hello = read_hello(&mut conn).await?;
    match hello {
        ControlChannelHello(_, service_digest) => {
            do_control_channel_handshake(conn, addr, services, control_channels, service_digest)
                .await?;
        }
        DataChannelHello(_, nonce) => {
            do_data_channel_handshake(conn, control_channels, nonce).await?;
        }
    }
    Ok(())
}
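
/// Performs the control channel handshake: sends a nonce, verifies the client's
/// auth digest against the service token, and installs a `ControlChannelHandle`.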
async fn do_control_channel_handshake<T: 'static + Transport>(
    mut conn: T::Stream,
    addr: SocketAddr,
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
    service_digest: ServiceDigest,
) -> Result<()> {
    info!("New control channel incoming from {}", addr);

    // Generate a nonce
    let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
    rand::thread_rng().fill_bytes(&mut nonce);

    // Send hello
    let hello_send = Hello::ControlChannelHello(
        protocol::CURRENT_PROTO_VRESION,
        nonce.clone().try_into().unwrap(),
    );
    conn.write_all(&bincode::serialize(&hello_send).unwrap())
        .await?;

    // Lookup the service
    let services_guard = services.read().await;
    let service_config = match services_guard.get(&service_digest) {
        Some(v) => v,
        None => {
            conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
                .await?;
            bail!("No such service {}", hex::encode(&service_digest));
        }
    };
    let service_name = &service_config.name;

    // Calculate the checksum
    let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
    concat.append(&mut nonce);

    // Read auth
    let protocol::Auth(d) = read_auth(&mut conn).await?;

    // Validate
    let session_key = protocol::digest(&concat);
    if session_key != d {
        conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
            .await?;
        debug!(
            "Expected {}, but got {}",
            hex::encode(session_key),
            hex::encode(d)
        );
        bail!("Service {} failed the authentication", service_name);
    } else {
        let mut h = control_channels.write().await;
        if h.remove1(&service_digest).is_some() {
            warn!(
                "Dropping previous control channel for digest {}",
                hex::encode(service_digest)
            );
        }

        let service_config = service_config.clone();
        drop(services_guard);

        // Send ack
        conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
            .await?;

        info!(service = %service_config.name, "Control channel established");
        let handle = ControlChannelHandle::new(conn, service_config);

        // Drop the old handle
        let _ = h.insert(service_digest, session_key, handle);
    }

    Ok(())
}
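
/// Matches an incoming data channel to its control channel by the nonce (session key)
/// and forwards the stream to that channel's connection pool.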
async fn do_data_channel_handshake<T: Transport>(
    conn: T::Stream,
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
    nonce: Nonce,
) -> Result<()> {
    // Validate
    let control_channels_guard = control_channels.read().await;
    match control_channels_guard.get2(&nonce) {
        Some(c_ch) => {
            // Send the data channel to the corresponding control channel
            c_ch.conn_pool.data_ch_tx.send(conn).await?;
        }
        None => {
            warn!("Data channel has incorrect nonce");
        }
    }
    Ok(())
}
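
/// The actor behind a `ControlChannelHandle`: listens on the service's public
/// `bind_addr` and asks the client for data channels as visitors arrive.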
struct ControlChannel<T: Transport> {
    conn: T::Stream,
    service: ServerServiceConfig,
    shutdown_rx: oneshot::Receiver<bool>,
    visitor_tx: mpsc::Sender<TcpStream>,
}
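
/// Keeps a control channel actor alive: dropping the handle drops `_shutdown_tx`,
/// which resolves `shutdown_rx` and stops the actor.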
struct ControlChannelHandle<T: Transport> {
    _shutdown_tx: oneshot::Sender<bool>,
    conn_pool: ConnectionPoolHandle<T>,
}
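
// Handle/actor pattern: `new` spawns the `ControlChannel` actor and returns
// a handle that owns the shutdown sender and the connection pool.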
impl<T: 'static + Transport> ControlChannelHandle<T> {
    fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
        let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
        let name = service.name.clone();
        let conn_pool = ConnectionPoolHandle::new();
        let actor: ControlChannel<T> = ControlChannel {
            conn,
            shutdown_rx,
            service,
            visitor_tx: conn_pool.visitor_tx.clone(),
        };

        tokio::spawn(async move {
            if let Err(err) = actor.run().await {
                error!(%name, "{}", err);
            }
        });

        ControlChannelHandle {
            _shutdown_tx,
            conn_pool,
        }
    }
}
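
// The main loop of a service: accepts visitors on the service's `bind_addr`,
// requests data channels from the client, and exits when the handle is dropped.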
impl<T: Transport> ControlChannel<T> {
    #[tracing::instrument(skip(self), fields(service = %self.service.name))]
    async fn run(mut self) -> Result<()> {
        let l = match TcpListener::bind(&self.service.bind_addr).await {
            Ok(v) => v,
            Err(e) => {
                let duration = Duration::from_secs(1);
                error!(
                    "Failed to listen on service.bind_addr: {}. Retry in {:?}...",
                    e, duration
                );
                time::sleep(duration).await;
                TcpListener::bind(&self.service.bind_addr).await?
            }
        };
        info!("Listening at {}", &self.service.bind_addr);

        // Each message on this channel requests one more data channel from the client
        let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();

        // Move the control connection out of `self` so the spawned task can own it
        // while the rest of this function keeps using the other fields
        let mut conn = self.conn;
        tokio::spawn(async move {
            let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
            while data_req_rx.recv().await.is_some() {
                if conn.write_all(&cmd).await.is_err() {
                    break;
                }
            }
        });

        // Pre-warm the pool with POOL_SIZE data channels
        for _ in 0..POOL_SIZE {
            if let Err(e) = data_req_tx.send(0) {
                error!("Failed to request data channel {}", e);
            };
        }

        let mut backoff = ExponentialBackoff {
            max_interval: Duration::from_secs(1),
            max_elapsed_time: None,
            ..Default::default()
        };

        loop {
            tokio::select! {
                val = l.accept() => {
                    match val {
                        Err(e) => {
                            error!("{}. Sleep for a while", e);
                            if let Some(d) = backoff.next_backoff() {
                                time::sleep(d).await;
                            } else {
                                error!("Too many retries. Aborting...");
                                break;
                            }
                        },
                        Ok((incoming, addr)) => {
                            // Request a replacement for the data channel about to be consumed
                            if let Err(e) = data_req_tx.send(0) {
                                error!("{}", e);
                                break;
                            };
                            backoff.reset();
                            debug!("New visitor from {}", addr);

                            // Hand the visitor over to the connection pool
                            let _ = self.visitor_tx.send(incoming).await;
                        }
                    }
                },
                _ = &mut self.shutdown_rx => {
                    break;
                }
            }
        }

        info!("Service shutting down");
        Ok(())
    }
}
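
/// Pairs visitor connections with data channels provided by the client.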
#[derive(Debug)]
struct ConnectionPool<T: Transport> {
    visitor_rx: mpsc::Receiver<TcpStream>,
    data_ch_rx: mpsc::Receiver<T::Stream>,
}
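
/// Sending halves of the pool's channels: `visitor_tx` is used by the control channel,
/// `data_ch_tx` by the data channel handshake.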
struct ConnectionPoolHandle<T: Transport> {
    visitor_tx: mpsc::Sender<TcpStream>,
    data_ch_tx: mpsc::Sender<T::Stream>,
}

impl<T: 'static + Transport> ConnectionPoolHandle<T> {
    fn new() -> ConnectionPoolHandle<T> {
        let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
        let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
        let conn_pool: ConnectionPool<T> = ConnectionPool {
            data_ch_rx,
            visitor_rx,
        };

        tokio::spawn(async move { conn_pool.run().await });

        ConnectionPoolHandle {
            data_ch_tx,
            visitor_tx,
        }
    }
}
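
// For every visitor, takes one data channel from the pool, sends `StartForward`,
// and then copies bytes in both directions until either side closes.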
impl<T: Transport> ConnectionPool<T> {
    #[tracing::instrument]
    async fn run(mut self) {
        while let Some(mut visitor) = self.visitor_rx.recv().await {
            if let Some(mut ch) = self.data_ch_rx.recv().await {
                tokio::spawn(async move {
                    let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
                    if ch.write_all(&cmd).await.is_ok() {
                        let _ = copy_bidirectional(&mut ch, &mut visitor).await;
                    }
                });
            } else {
                break;
            }
        }
    }
}