Pārlūkot izejas kodu

refactor: fix clippy, merge imports

Fix lints of clippy
Merge imports
Yujia Qiao 4 gadi atpakaļ
vecāks
revīzija
f92398ea31
9 mainītis faili ar 59 papildinājumiem un 85 dzēšanām
  1. 1 0
      .rustfmt.toml
  2. 13 19
      src/client.rs
  3. 3 5
      src/config.rs
  4. 10 14
      src/lib.rs
  5. 0 1
      src/main.rs
  6. 18 25
      src/server.rs
  7. 3 5
      src/transport/mod.rs
  8. 3 2
      src/transport/tcp.rs
  9. 8 14
      src/transport/tls.rs

+ 1 - 0
.rustfmt.toml

@@ -0,0 +1 @@
+imports_granularity = "module"

+ 13 - 19
src/client.rs

@@ -2,12 +2,11 @@ use std::collections::HashMap;
 use std::sync::Arc;
 
 use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
+use crate::protocol::Hello::{self, *};
 use crate::protocol::{
-    self, Ack, Auth, ControlChannelCmd, DataChannelCmd,
-    Hello::{self, *},
-    CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
+    self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
+    DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
 };
-use crate::protocol::{read_ack, read_control_cmd, read_data_cmd, read_hello};
 use crate::transport::{TcpTransport, TlsTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
 use backoff::ExponentialBackoff;
@@ -28,11 +27,11 @@ pub async fn run_client(config: &Config) -> Result<()> {
 
     match config.transport.transport_type {
         TransportType::Tcp => {
-            let mut client = Client::<TcpTransport>::from(&config).await?;
+            let mut client = Client::<TcpTransport>::from(config).await?;
             client.run().await
         }
         TransportType::Tls => {
-            let mut client = Client::<TlsTransport>::from(&config).await?;
+            let mut client = Client::<TlsTransport>::from(config).await?;
             client.run().await
         }
     }
@@ -244,19 +243,14 @@ impl ControlChannelHandle {
 
         tokio::spawn(
             async move {
-                loop {
-                    if let Err(err) = s
-                        .run()
-                        .await
-                        .with_context(|| "Failed to run the control channel")
-                    {
-                        let duration = Duration::from_secs(2);
-                        error!("{:?}\n\nRetry in {:?}...", err, duration);
-                        time::sleep(duration).await;
-                    } else {
-                        // Shutdown
-                        break;
-                    }
+                while let Err(err) = s
+                    .run()
+                    .await
+                    .with_context(|| "Failed to run the control channel")
+                {
+                    let duration = Duration::from_secs(2);
+                    error!("{:?}\n\nRetry in {:?}...", err, duration);
+                    time::sleep(duration).await;
                 }
             }
             .instrument(Span::current()),

+ 3 - 5
src/config.rs

@@ -1,9 +1,8 @@
 use anyhow::{anyhow, bail, Context, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
-use std::path::PathBuf;
+use std::path::Path;
 use tokio::fs;
-use toml;
 
 #[derive(Debug, Serialize, Deserialize, Copy, Clone)]
 pub enum TransportType {
@@ -81,8 +80,7 @@ pub struct Config {
 
 impl Config {
     fn from_str(s: &str) -> Result<Config> {
-        let mut config: Config =
-            toml::from_str(&s).with_context(|| "Failed to parse the config")?;
+        let mut config: Config = toml::from_str(s).with_context(|| "Failed to parse the config")?;
 
         if let Some(server) = config.server.as_mut() {
             Config::validate_server_config(server)?;
@@ -158,7 +156,7 @@ impl Config {
         }
     }
 
-    pub async fn from_file(path: &PathBuf) -> Result<Config> {
+    pub async fn from_file(path: &Path) -> Result<Config> {
         let s: String = fs::read_to_string(path)
             .await
             .with_context(|| format!("Failed to read the config {:?}", path))?;

+ 10 - 14
src/lib.rs

@@ -26,7 +26,7 @@ pub async fn run(args: &Cli) -> Result<()> {
     // Raise `nofile` limit on linux and mac
     fdlimit::raise_fd_limit();
 
-    match determine_run_mode(&config, &args) {
+    match determine_run_mode(&config, args) {
         RunMode::Undetermine => Err(anyhow!("Cannot determine running as a server or a client")),
         RunMode::Client => run_client(&config).await,
         RunMode::Server => run_server(&config).await,
@@ -44,20 +44,16 @@ fn determine_run_mode(config: &Config, args: &Cli) -> RunMode {
     use RunMode::*;
     if args.client && args.server {
         Undetermine
+    } else if args.client {
+        Client
+    } else if args.server {
+        Server
+    } else if config.client.is_some() && config.server.is_none() {
+        Client
+    } else if config.server.is_some() && config.client.is_none() {
+        Server
     } else {
-        if args.client {
-            Client
-        } else if args.server {
-            Server
-        } else {
-            if config.server.is_some() && config.client.is_none() {
-                Server
-            } else if config.client.is_some() && config.server.is_none() {
-                Client
-            } else {
-                Undetermine
-            }
-        }
+        Undetermine
     }
 }
 

+ 0 - 1
src/main.rs

@@ -1,7 +1,6 @@
 use anyhow::Result;
 use clap::Parser;
 use rathole::{run, Cli};
-use tokio;
 
 #[tokio::main]
 async fn main() -> Result<()> {

+ 18 - 25
src/server.rs

@@ -1,13 +1,13 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
 use crate::multi_map::MultiMap;
+use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
 use crate::protocol::{
-    self, Ack, ControlChannelCmd, DataChannelCmd, Hello, Hello::ControlChannelHello,
-    Hello::DataChannelHello, HASH_WIDTH_IN_BYTES,
+    self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, HASH_WIDTH_IN_BYTES,
 };
-use crate::protocol::{read_auth, read_hello};
 use crate::transport::{TcpTransport, TlsTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
-use backoff::{backoff::Backoff, ExponentialBackoff};
+use backoff::backoff::Backoff;
+use backoff::ExponentialBackoff;
 use rand::RngCore;
 use std::collections::HashMap;
 use std::net::SocketAddr;
@@ -15,8 +15,7 @@ use std::sync::Arc;
 use std::time::Duration;
 use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
 use tokio::net::{TcpListener, TcpStream};
-use tokio::sync::mpsc;
-use tokio::sync::{oneshot, RwLock};
+use tokio::sync::{mpsc, oneshot, RwLock};
 use tokio::time;
 use tracing::{debug, error, info, info_span, warn, Instrument};
 
@@ -190,9 +189,7 @@ async fn do_control_channel_handshake<T: 'static + Transport>(
     concat.append(&mut nonce);
 
     // Read auth
-    let d = match read_auth(&mut conn).await? {
-        protocol::Auth(v) => v,
-    };
+    let protocol::Auth(d) = read_auth(&mut conn).await?;
 
     // Validate
     let session_key = protocol::digest(&concat);
@@ -259,13 +256,13 @@ struct ControlChannel<T: Transport> {
 }
 
 struct ControlChannelHandle<T: Transport> {
-    shutdown_tx: oneshot::Sender<bool>,
+    _shutdown_tx: oneshot::Sender<bool>,
     conn_pool: ConnectionPoolHandle<T>,
 }
 
 impl<T: 'static + Transport> ControlChannelHandle<T> {
     fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
-        let (shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
+        let (_shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
         let name = service.name.clone();
         let conn_pool = ConnectionPoolHandle::new();
         let actor: ControlChannel<T> = ControlChannel {
@@ -282,7 +279,7 @@ impl<T: 'static + Transport> ControlChannelHandle<T> {
         });
 
         ControlChannelHandle {
-            shutdown_tx,
+            _shutdown_tx,
             conn_pool,
         }
     }
@@ -309,7 +306,7 @@ impl<T: Transport> ControlChannel<T> {
         let (data_req_tx, mut data_req_rx) = mpsc::unbounded_channel::<u8>();
         tokio::spawn(async move {
             let cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
-            while let Some(_) = data_req_rx.recv().await {
+            while data_req_rx.recv().await.is_some() {
                 if self.conn.write_all(&cmd).await.is_err() {
                     break;
                 }
@@ -396,18 +393,14 @@ impl<T: 'static + Transport> ConnectionPoolHandle<T> {
 impl<T: Transport> ConnectionPool<T> {
     #[tracing::instrument]
     async fn run(mut self) {
-        loop {
-            if let Some(mut visitor) = self.visitor_rx.recv().await {
-                if let Some(mut ch) = self.data_ch_rx.recv().await {
-                    tokio::spawn(async move {
-                        let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
-                        if ch.write_all(&cmd).await.is_ok() {
-                            let _ = copy_bidirectional(&mut ch, &mut visitor).await;
-                        }
-                    });
-                } else {
-                    break;
-                }
+        while let Some(mut visitor) = self.visitor_rx.recv().await {
+            if let Some(mut ch) = self.data_ch_rx.recv().await {
+                tokio::spawn(async move {
+                    let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
+                    if ch.write_all(&cmd).await.is_ok() {
+                        let _ = copy_bidirectional(&mut ch, &mut visitor).await;
+                    }
+                });
             } else {
                 break;
             }

+ 3 - 5
src/transport/mod.rs

@@ -3,10 +3,8 @@ use anyhow::Result;
 use async_trait::async_trait;
 use std::fmt::Debug;
 use std::net::SocketAddr;
-use tokio::{
-    io::{AsyncRead, AsyncWrite},
-    net::ToSocketAddrs,
-};
+use tokio::io::{AsyncRead, AsyncWrite};
+use tokio::net::ToSocketAddrs;
 
 #[async_trait]
 pub trait Transport: Debug + Send + Sync {
@@ -16,7 +14,7 @@ pub trait Transport: Debug + Send + Sync {
     async fn new(config: &TransportConfig) -> Result<Box<Self>>;
     async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>;
     async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)>;
-    async fn connect(&self, addr: &String) -> Result<Self::Stream>;
+    async fn connect(&self, addr: &str) -> Result<Self::Stream>;
 }
 
 mod tcp;

+ 3 - 2
src/transport/tcp.rs

@@ -1,4 +1,5 @@
-use crate::{config::TransportConfig, helper::set_tcp_keepalive};
+use crate::config::TransportConfig;
+use crate::helper::set_tcp_keepalive;
 
 use super::Transport;
 use anyhow::Result;
@@ -28,7 +29,7 @@ impl Transport for TcpTransport {
         Ok((s, addr))
     }
 
-    async fn connect(&self, addr: &String) -> Result<Self::Stream> {
+    async fn connect(&self, addr: &str) -> Result<Self::Stream> {
         let s = TcpStream::connect(addr).await?;
         if let Err(e) = set_tcp_keepalive(&s) {
             error!(

+ 8 - 14
src/transport/tls.rs

@@ -1,20 +1,14 @@
 use std::net::SocketAddr;
 
 use super::Transport;
-use crate::{
-    config::{TlsConfig, TransportConfig},
-    helper::set_tcp_keepalive,
-};
+use crate::config::{TlsConfig, TransportConfig};
+use crate::helper::set_tcp_keepalive;
 use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
-use tokio::{
-    fs,
-    net::{TcpListener, TcpStream, ToSocketAddrs},
-};
-use tokio_native_tls::{
-    native_tls::{self, Certificate, Identity},
-    TlsAcceptor, TlsConnector, TlsStream,
-};
+use tokio::fs;
+use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
+use tokio_native_tls::native_tls::{self, Certificate, Identity};
+use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
 use tracing::error;
 
 #[derive(Debug)]
@@ -39,7 +33,7 @@ impl Transport for TlsTransport {
         let connector = match config.trusted_root.as_ref() {
             Some(path) => {
                 let s = fs::read_to_string(path).await?;
-                let cert = Certificate::from_pem(&s.as_bytes())?;
+                let cert = Certificate::from_pem(s.as_bytes())?;
                 let connector = native_tls::TlsConnector::builder()
                     .add_root_certificate(cert)
                     .build()?;
@@ -74,7 +68,7 @@ impl Transport for TlsTransport {
         Ok((conn, addr))
     }
 
-    async fn connect(&self, addr: &String) -> Result<Self::Stream> {
+    async fn connect(&self, addr: &str) -> Result<Self::Stream> {
         let conn = TcpStream::connect(&addr).await?;
         if let Err(e) = set_tcp_keepalive(&conn) {
             error!(