|
|
@@ -5,9 +5,9 @@ use std::{
|
|
|
time::Duration,
|
|
|
};
|
|
|
|
|
|
-use anyhow::{Context, Result};
|
|
|
+use anyhow::{anyhow, Context, Result};
|
|
|
use socket2::{SockRef, TcpKeepalive};
|
|
|
-use tokio::net::{TcpStream, ToSocketAddrs, UdpSocket};
|
|
|
+use tokio::net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket};
|
|
|
use tracing::error;
|
|
|
|
|
|
// Tokio hesitates to expose this option...So we have to do it on our own :(
|
|
|
@@ -39,8 +39,17 @@ pub fn feature_not_compile(feature: &str) -> ! {
|
|
|
|
|
|
/// Create a UDP socket and connect to `addr`
|
|
|
pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
|
|
|
- // FIXME: This only works for IPv4
|
|
|
- let s = UdpSocket::bind("0.0.0.0:0").await?;
|
|
|
+ let addr = lookup_host(addr)
|
|
|
+ .await?
|
|
|
+ .next()
|
|
|
+ .ok_or(anyhow!("Failed to lookup the host"))?;
|
|
|
+
|
|
|
+ let bind_addr = match addr {
|
|
|
+ SocketAddr::V4(_) => "0.0.0.0:0",
|
|
|
+ SocketAddr::V6(_) => ":::0",
|
|
|
+ };
|
|
|
+
|
|
|
+ let s = UdpSocket::bind(bind_addr).await?;
|
|
|
s.connect(addr).await?;
|
|
|
Ok(s)
|
|
|
}
|
|
|
@@ -70,8 +79,12 @@ pub fn floor_to_pow_of_2(x: usize) -> usize {
|
|
|
|
|
|
#[cfg(test)]
|
|
|
mod test {
|
|
|
+ use tokio::net::UdpSocket;
|
|
|
+
|
|
|
use crate::helper::{floor_to_pow_of_2, log2_floor};
|
|
|
|
|
|
+ use super::udp_connect;
|
|
|
+
|
|
|
#[test]
|
|
|
fn test_log2_floor() {
|
|
|
let t = [
|
|
|
@@ -111,4 +124,29 @@ mod test {
|
|
|
assert_eq!(floor_to_pow_of_2(t.0), t.1);
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn test_udp_connect() {
|
|
|
+ let hello = "HELLO";
|
|
|
+
|
|
|
+ let t = [("0.0.0.0:2333", "127.0.0.1:2333"), (":::2333", "::1:2333")];
|
|
|
+ for t in t {
|
|
|
+ let listener = UdpSocket::bind(t.0).await.unwrap();
|
|
|
+
|
|
|
+ let handle = tokio::spawn(async move {
|
|
|
+ let s = udp_connect(t.1).await.unwrap();
|
|
|
+ s.send(hello.as_bytes()).await.unwrap();
|
|
|
+ let mut buf = [0u8; 16];
|
|
|
+ let n = s.recv(&mut buf).await.unwrap();
|
|
|
+ assert_eq!(&buf[..n], hello.as_bytes());
|
|
|
+ });
|
|
|
+
|
|
|
+ let mut buf = [0u8; 16];
|
|
|
+ let (n, addr) = listener.recv_from(&mut buf).await.unwrap();
|
|
|
+ assert_eq!(&buf[..n], hello.as_bytes());
|
|
|
+ listener.send_to(&buf[..n], addr).await.unwrap();
|
|
|
+
|
|
|
+ handle.await.unwrap();
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|