Explorar o código

fix: support IPv6 `local_address` for UDP

Yujia Qiao %!s(int64=4) %!d(string=hai) anos
pai
achega
7a7eef11bc
Modificáronse 1 ficheiros con 42 adicións e 4 borrados
  1. 42 4
      src/helper.rs

+ 42 - 4
src/helper.rs

@@ -5,9 +5,9 @@ use std::{
     time::Duration,
 };
 
-use anyhow::{Context, Result};
+use anyhow::{anyhow, Context, Result};
 use socket2::{SockRef, TcpKeepalive};
-use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
+use tokio::net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket};
 use tracing::error;
 
 // Tokio hesitates to expose this option...So we have to do it on our own :(
@@ -39,8 +39,17 @@ pub fn feature_not_compile(feature: &str) -> ! {
 
 /// Create a UDP socket and connect to `addr`
 pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
-    // FIXME: This only works for IPv4
-    let s = UdpSocket::bind("0.0.0.0:0").await?;
+    let addr = lookup_host(addr)
+        .await?
+        .next()
+        .ok_or(anyhow!("Failed to lookup the host"))?;
+
+    let bind_addr = match addr {
+        SocketAddr::V4(_) => "0.0.0.0:0",
+        SocketAddr::V6(_) => ":::0",
+    };
+
+    let s = UdpSocket::bind(bind_addr).await?;
     s.connect(addr).await?;
     Ok(s)
 }
@@ -70,8 +79,12 @@ pub fn floor_to_pow_of_2(x: usize) -> usize {
 
 #[cfg(test)]
 mod test {
+    use tokio::net::UdpSocket;
+
     use crate::helper::{floor_to_pow_of_2, log2_floor};
 
+    use super::udp_connect;
+
     #[test]
     fn test_log2_floor() {
         let t = [
@@ -111,4 +124,29 @@ mod test {
             assert_eq!(floor_to_pow_of_2(t.0), t.1);
         }
     }
+
+    #[tokio::test]
+    async fn test_udp_connect() {
+        let hello = "HELLO";
+
+        let t = [("0.0.0.0:2333", "127.0.0.1:2333"), (":::2333", "::1:2333")];
+        for t in t {
+            let listener = UdpSocket::bind(t.0).await.unwrap();
+
+            let handle = tokio::spawn(async move {
+                let s = udp_connect(t.1).await.unwrap();
+                s.send(hello.as_bytes()).await.unwrap();
+                let mut buf = [0u8; 16];
+                let n = s.recv(&mut buf).await.unwrap();
+                assert_eq!(&buf[..n], hello.as_bytes());
+            });
+
+            let mut buf = [0u8; 16];
+            let (n, addr) = listener.recv_from(&mut buf).await.unwrap();
+            assert_eq!(&buf[..n], hello.as_bytes());
+            listener.send_to(&buf[..n], addr).await.unwrap();
+
+            handle.await.unwrap();
+        }
+    }
 }