Переглянути джерело

fix: throw errors when the service type or protocol version doesn't match (#112)

* fix: print errors when service types don't match

* fix: validate the protocol version when handshake
Yujia Qiao 4 роки тому
батько
коміт
cdbf8781e4
2 змінених файлів з 30 додано та 7 видалено
  1. 11 6
      src/client.rs
  2. 19 1
      src/protocol.rs

+ 11 - 6
src/client.rs

@@ -1,4 +1,4 @@
-use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
+use crate::config::{ClientConfig, ClientServiceConfig, Config, ServiceType, TransportType};
 use crate::config_watcher::ServiceChange;
 use crate::helper::udp_connect;
 use crate::protocol::Hello::{self, *};
@@ -150,9 +150,9 @@ impl<'a, T: 'static + Transport> Client<'a, T> {
 struct RunDataChannelArgs<T: Transport> {
     session_key: Nonce,
     remote_addr: String,
-    local_addr: String,
     connector: Arc<T>,
     socket_opts: SocketOpts,
+    service: ClientServiceConfig,
 }
 
 async fn do_data_channel_handshake<T: Transport>(
@@ -201,10 +201,16 @@ async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Res
     // Forward
     match read_data_cmd(&mut conn).await? {
         DataChannelCmd::StartForwardTcp => {
-            run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
+            if args.service.service_type != ServiceType::Tcp {
+                bail!("Expect TCP traffic. Please check the configuration.")
+            }
+            run_data_channel_for_tcp::<T>(conn, &args.service.local_addr).await?;
         }
         DataChannelCmd::StartForwardUdp => {
-            run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
+            if args.service.service_type != ServiceType::Udp {
+                bail!("Expect UDP traffic. Please check the configuration.")
+            }
+            run_data_channel_for_udp::<T>(conn, &args.service.local_addr).await?;
         }
     }
     Ok(())
@@ -427,15 +433,14 @@ impl<T: 'static + Transport> ControlChannel<T> {
         info!("Control channel established");
 
         let remote_addr = self.remote_addr.clone();
-        let local_addr = self.service.local_addr.clone();
         // Socket options for the data channel
         let socket_opts = SocketOpts::from_client_cfg(&self.service);
         let data_ch_args = Arc::new(RunDataChannelArgs {
             session_key,
             remote_addr,
-            local_addr,
             connector: self.transport.clone(),
             socket_opts,
+            service: self.service.clone(),
         });
 
         loop {

+ 19 - 1
src/protocol.rs

@@ -1,6 +1,6 @@
 pub const HASH_WIDTH_IN_BYTES: usize = 32;
 
-use anyhow::{Context, Result};
+use anyhow::{bail, Context, Result};
 use bytes::{Bytes, BytesMut};
 use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
@@ -180,6 +180,24 @@ pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Resu
         .await
         .with_context(|| "Failed to read hello")?;
     let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
+
+    match hello {
+        Hello::ControlChannelHello(v, _) => {
+            if v != CURRENT_PROTO_VERSION {
+                bail!(
+                    "Protocol version mismatched. Expected {}, got {}. Please update `rathole`.",
+                    CURRENT_PROTO_VERSION,
+                    v
+                );
+            }
+        }
+        Hello::DataChannelHello(v, _) => {
+            // This assert should not fail because the version has already been
+            // checked by ControlChannelHello.
+            assert_eq!(v, CURRENT_PROTO_VERSION);
+        }
+    }
+
     Ok(hello)
 }