Sfoglia il codice sorgente

fix: cancel safety

Yujia Qiao 4 anni fa
parent
commit
d0d4f61efd
3 ha cambiato i file con 25 aggiunte e 20 eliminazioni
  1. 1 1
      src/client.rs
  2. 20 14
      src/protocol.rs
  3. 4 5
      src/server.rs

+ 1 - 1
src/client.rs

@@ -13,7 +13,7 @@ use bytes::{Bytes, BytesMut};
 use std::collections::HashMap;
 use std::net::SocketAddr;
 use std::sync::Arc;
-use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
+use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
 use tokio::net::{TcpStream, UdpSocket};
 use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
 use tokio::time::{self, Duration};

+ 20 - 14
src/protocol.rs

@@ -6,6 +6,7 @@ use lazy_static::lazy_static;
 use serde::{Deserialize, Serialize};
 use std::net::SocketAddr;
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
+use tracing::trace;
 
 type ProtocolVersion = u8;
 const PROTO_V0: u8 = 0u8;
@@ -70,12 +71,14 @@ pub struct UdpTraffic {
 
 impl UdpTraffic {
     pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
-        let v = bincode::serialize(&UdpHeader {
+        let hdr = UdpHeader {
             from: self.from,
             len: self.data.len() as UdpPacketLen,
-        })
-        .unwrap();
+        };
+
+        let v = bincode::serialize(&hdr).unwrap();
 
+        trace!("Write {:?} of length {}", hdr, v.len());
         writer.write_u16(v.len() as u16).await?;
         writer.write_all(&v).await?;
 
@@ -90,12 +93,14 @@ impl UdpTraffic {
         from: SocketAddr,
         data: &[u8],
     ) -> Result<()> {
-        let v = bincode::serialize(&UdpHeader {
+        let hdr = UdpHeader {
             from,
             len: data.len() as UdpPacketLen,
-        })
-        .unwrap();
+        };
+
+        let v = bincode::serialize(&hdr).unwrap();
 
+        trace!("Write {:?} of length {}", hdr, v.len());
         writer.write_u16(v.len() as u16).await?;
         writer.write_all(&v).await?;
 
@@ -104,24 +109,25 @@ impl UdpTraffic {
         Ok(())
     }
 
-    pub async fn read<T: AsyncRead + Unpin>(reader: &mut T) -> Result<UdpTraffic> {
-        let len = reader.read_u16().await? as usize;
-
+    pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u16) -> Result<UdpTraffic> {
         let mut buf = Vec::new();
-        buf.resize(len, 0);
+        buf.resize(hdr_len as usize, 0);
         reader
             .read_exact(&mut buf)
             .await
             .with_context(|| "Failed to read udp header")?;
-        let header: UdpHeader =
-            bincode::deserialize(&buf).with_context(|| "Failed to deserialize udp header")?;
+
+        let hdr: UdpHeader =
+            bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?;
+
+        trace!("hdr {:?}", hdr);
 
         let mut data = BytesMut::new();
-        data.resize(header.len as usize, 0);
+        data.resize(hdr.len as usize, 0);
         reader.read_exact(&mut data).await?;
 
         Ok(UdpTraffic {
-            from: header.from,
+            from: hdr.from,
             data: data.freeze(),
         })
     }

+ 4 - 5
src/server.rs

@@ -14,10 +14,9 @@ use backoff::ExponentialBackoff;
 
 use rand::RngCore;
 use std::collections::HashMap;
-use std::net::SocketAddr;
 use std::sync::Arc;
 use std::time::Duration;
-use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
+use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
 use tokio::net::{TcpListener, TcpStream, UdpSocket};
 use tokio::sync::{broadcast, mpsc, RwLock};
 use tokio::time;
@@ -618,10 +617,10 @@ async fn run_udp_connection_pool<T: Transport>(
             },
 
             // Forward outbound traffic from the client to the visitor
-            t = UdpTraffic::read(&mut conn) => {
-                let t = t?;
+            hdr_len = conn.read_u16() => {
+                let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
                 l.send_to(&t.data, t.from).await?;
-            },
+            }
 
             _ = shutdown_rx.recv() => {
                 break;