Pārlūkot izejas kodu

refactor: facilitate tests

Yujia Qiao 4 gadi atpakaļ
vecāks
revīzija
b8e824849a
4 mainītis faili ar 43 papildinājumiem un 21 dzēšanām
  1. 13 9
      src/client.rs
  2. 4 5
      src/lib.rs
  3. 20 1
      src/main.rs
  4. 6 6
      src/server.rs

+ 13 - 9
src/client.rs

@@ -13,12 +13,12 @@ use backoff::ExponentialBackoff;
 
 use tokio::io::{copy_bidirectional, AsyncWriteExt};
 use tokio::net::TcpStream;
-use tokio::sync::oneshot;
+use tokio::sync::{broadcast, oneshot};
 use tokio::time::{self, Duration};
 use tracing::{debug, error, info, instrument, Instrument, Span};
 
 // The entrypoint of running a client
-pub async fn run_client(config: &Config) -> Result<()> {
+pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
     let config = match &config.client {
         Some(v) => v,
         None => {
@@ -29,11 +29,11 @@ pub async fn run_client(config: &Config) -> Result<()> {
     match config.transport.transport_type {
         TransportType::Tcp => {
             let mut client = Client::<TcpTransport>::from(config).await?;
-            client.run().await
+            client.run(shutdown_rx).await
         }
         TransportType::Tls => {
             let mut client = Client::<TlsTransport>::from(config).await?;
-            client.run().await
+            client.run(shutdown_rx).await
         }
     }
 }
@@ -54,12 +54,16 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
         Ok(Client {
             config,
             service_handles: HashMap::new(),
-            transport: Arc::new(*T::new(&config.transport).await?),
+            transport: Arc::new(
+                *T::new(&config.transport)
+                    .await
+                    .with_context(|| "Failed to create the transport")?,
+            ),
         })
     }
 
     // The entrypoint of Client
-    async fn run(&mut self) -> Result<()> {
+    async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
         for (name, config) in &self.config.services {
             // Create a control channel for each service defined
             let handle = ControlChannelHandle::new(
@@ -74,9 +78,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
         // Wait for the shutdown signal
         loop {
             tokio::select! {
-                val = tokio::signal::ctrl_c() => {
+                val = shutdown_rx.recv() => {
                     match val {
-                        Ok(()) => {}
+                        Ok(_) => {}
                         Err(err) => {
                             error!("Unable to listen for shutdown signal: {}", err);
                         }
@@ -258,7 +262,7 @@ impl ControlChannelHandle {
                     .await
                     .with_context(|| "Failed to run the control channel")
                 {
-                    let duration = Duration::from_secs(2);
+                    let duration = Duration::from_secs(1);
                     error!("{:?}\n\nRetry in {:?}...", err, duration);
                     time::sleep(duration).await;
                 }

+ 4 - 5
src/lib.rs

@@ -11,16 +11,15 @@ pub use cli::Cli;
 pub use config::Config;
 
 use anyhow::{anyhow, Result};
+use tokio::sync::broadcast;
 use tracing::debug;
 
 use client::run_client;
 use server::run_server;
 
-pub async fn run(args: &Cli) -> Result<()> {
+pub async fn run(args: &Cli, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
     let config = Config::from_file(&args.config_path).await?;
 
-    tracing_subscriber::fmt::init();
-
     debug!("{:?}", config);
 
     // Raise `nofile` limit on linux and mac
@@ -28,8 +27,8 @@ pub async fn run(args: &Cli) -> Result<()> {
 
     match determine_run_mode(&config, args) {
         RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
-        RunMode::Client => run_client(&config).await,
-        RunMode::Server => run_server(&config).await,
+        RunMode::Client => run_client(&config, shutdown_rx).await,
+        RunMode::Server => run_server(&config, shutdown_rx).await,
     }
 }
 

+ 20 - 1
src/main.rs

@@ -1,9 +1,28 @@
 use anyhow::Result;
 use clap::Parser;
 use rathole::{run, Cli};
+use tokio::{signal, sync::broadcast};
 
 #[tokio::main]
 async fn main() -> Result<()> {
     let args = Cli::parse();
-    run(&args).await
+
+    let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
+    tokio::spawn(async move {
+        if let Err(e) = signal::ctrl_c().await {
+            // Something really weird happened. So just panic
+            panic!("Failed to listen for the ctrl-c signal: {:?}", e);
+        }
+
+        if let Err(e) = shutdown_tx.send(true) {
+            // shutdown signal must be catched and handle properly
+            // `rx` must not be dropped
+            panic!("Failed to send shutdown signal: {:?}", e);
+        }
+    });
+
+    // TODO: use level from config
+    tracing_subscriber::fmt::init();
+
+    run(&args, shutdown_rx).await
 }

+ 6 - 6
src/server.rs

@@ -15,7 +15,7 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::{mpsc, oneshot, RwLock};
+use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
 use tokio::time;
 use tracing::{debug, error, info, info_span, warn, Instrument};
 
@@ -26,7 +26,7 @@ const POOL_SIZE: usize = 64; // The number of cached connections
 const CHAN_SIZE: usize = 2048; // The capacity of various chans
 
 // The entrypoint of running a server
-pub async fn run_server(config: &Config) -> Result<()> {
+pub async fn run_server(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
     let config = match &config.server {
             Some(config) => config,
             None => {
@@ -38,11 +38,11 @@ pub async fn run_server(config: &Config) -> Result<()> {
     match config.transport.transport_type {
         TransportType::Tcp => {
             let mut server = Server::<TcpTransport>::from(config).await?;
-            server.run().await?;
+            server.run(shutdown_rx).await?;
         }
         TransportType::Tls => {
             let mut server = Server::<TlsTransport>::from(config).await?;
-            server.run().await?;
+            server.run(shutdown_rx).await?;
         }
     }
 
@@ -91,7 +91,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
     }
 
     // The entry point of Server
-    pub async fn run(&mut self) -> Result<()> {
+    pub async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
         // Listen at `server.bind_addr`
         let l = self
             .transport
@@ -146,7 +146,7 @@ impl<'a, T: 'static + Transport> Server<'a, T> {
                     }
                 },
                 // Wait for the shutdown signal
-                _ = tokio::signal::ctrl_c() => {
+                _ = shutdown_rx.recv() => {
                     info!("Shuting down gracefully...");
                     break;
                 }