helper.rs 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. use std::{
  2. collections::hash_map::DefaultHasher,
  3. hash::{Hash, Hasher},
  4. net::SocketAddr,
  5. time::Duration,
  6. };
  7. use anyhow::{anyhow, Context, Result};
  8. use socket2::{SockRef, TcpKeepalive};
  9. use tokio::net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket};
  10. use tracing::error;
  11. // Tokio hesitates to expose this option...So we have to do it on our own :(
  12. // The good news is that using socket2 it can be easily done, without losing portability.
  13. // See https://github.com/tokio-rs/tokio/issues/3082
  14. pub fn try_set_tcp_keepalive(conn: &TcpStream) -> Result<()> {
  15. let s = SockRef::from(conn);
  16. let keepalive = TcpKeepalive::new().with_time(Duration::from_secs(30));
  17. s.set_tcp_keepalive(&keepalive)
  18. .with_context(|| "Failed to set keepalive")
  19. }
  20. pub fn set_tcp_keepalive(conn: &TcpStream) {
  21. if let Err(e) = try_set_tcp_keepalive(conn) {
  22. error!(
  23. "Failed to set TCP keepalive. The connection maybe unstable: {:?}",
  24. e
  25. );
  26. }
  27. }
  /// Panic with a message saying that `feature` was not compiled into this binary.
  ///
  /// # Panics
  ///
  /// Always; this function never returns.
  // Presumably only reached from feature-gated fallbacks, so it may be unused in some builds.
  #[allow(dead_code)]
  pub fn feature_not_compile(feature: &str) -> ! {
      let message = format!(
          "The feature '{}' is not compiled in this binary. Please re-compile rathole",
          feature
      );
      panic!("{}", message)
  }
  35. /// Create a UDP socket and connect to `addr`
  36. pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
  37. let addr = lookup_host(addr)
  38. .await?
  39. .next()
  40. .ok_or(anyhow!("Failed to lookup the host"))?;
  41. let bind_addr = match addr {
  42. SocketAddr::V4(_) => "0.0.0.0:0",
  43. SocketAddr::V6(_) => ":::0",
  44. };
  45. let s = UdpSocket::bind(bind_addr).await?;
  46. s.connect(addr).await?;
  47. Ok(s)
  48. }
  // FIXME: The following helpers are intended for UDP load balancing but are not used yet.
  /// Hash a socket address to a `u64` using the standard library's [`DefaultHasher`].
  ///
  /// The value is only stable within one build: `DefaultHasher`'s algorithm is not
  /// guaranteed to stay the same across Rust releases.
  #[allow(dead_code)] // Not used yet; see the FIXME above.
  pub fn hash_socket_addr(a: &SocketAddr) -> u64 {
      let mut state = DefaultHasher::new();
      Hash::hash(a, &mut state);
      state.finish()
  }
  /// Return `floor(log2(x))` for `x > 0`. For `x == 0`, where log2 is undefined, returns `0`
  /// (the value the previous floating-point version produced).
  ///
  /// Uses integer bit counting instead of `f64::log2`. Converting a large `usize` to `f64`
  /// rounds it (for example, `usize::MAX` becomes 2^64 on 64-bit targets), which made the
  /// float version return a result one too large near the top of the range.
  /// `usize::ilog2` (Rust 1.67) could replace this once the minimum supported Rust version allows.
  #[allow(dead_code)] // Only used by the not-yet-used UDP load-balancing helpers and tests.
  fn log2_floor(x: usize) -> u8 {
      if x == 0 {
          return 0;
      }
      // The index of the highest set bit is floor(log2(x)); it is at most usize::BITS - 1,
      // so it always fits in a u8.
      (usize::BITS - 1 - x.leading_zeros()) as u8
  }
  /// Return the largest power of two that is `<= x`.
  ///
  /// For the degenerate input `x == 0` this returns `1`, matching the previous behaviour.
  ///
  /// The result comes directly from the highest set bit, so it is exact across the whole
  /// `usize` range. The earlier float-based path could round `x` up and then overflow the
  /// shift (e.g. `usize::MAX` led to `1 << 64`, which panics in debug builds).
  #[allow(dead_code)] // Not used yet; intended for UDP load balancing.
  pub fn floor_to_pow_of_2(x: usize) -> usize {
      if x == 0 {
          return 1;
      }
      1 << (usize::BITS - 1 - x.leading_zeros())
  }
  #[cfg(test)]
  mod test {
      use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

      use tokio::net::UdpSocket;

      use super::udp_connect;
      use crate::helper::{floor_to_pow_of_2, log2_floor};

      #[test]
      fn test_log2_floor() {
          let cases = [
              (2, 1),
              (3, 1),
              (4, 2),
              (8, 3),
              (9, 3),
              (15, 3),
              (16, 4),
              (1023, 9),
              (1024, 10),
              (2000, 10),
              (2048, 11),
          ];
          for (x, expected) in cases {
              assert_eq!(log2_floor(x), expected, "log2_floor({})", x);
          }
      }

      #[test]
      fn test_floor_to_pow_of_2() {
          let cases = [
              (1, 1),
              (2, 2),
              (3, 2),
              (4, 4),
              (5, 4),
              (15, 8),
              (31, 16),
              (33, 32),
              (1000, 512),
              (1500, 1024),
              (2300, 2048),
          ];
          for (x, expected) in cases {
              assert_eq!(floor_to_pow_of_2(x), expected, "floor_to_pow_of_2({})", x);
          }
      }

      #[tokio::test]
      async fn test_udp_connect() {
          let hello = "HELLO";
          // (IP the listener binds to, IP the client connects to), once per address family.
          let cases: [(IpAddr, IpAddr); 2] = [
              (Ipv4Addr::UNSPECIFIED.into(), Ipv4Addr::LOCALHOST.into()),
              (Ipv6Addr::UNSPECIFIED.into(), Ipv6Addr::LOCALHOST.into()),
          ];
          for (bind_ip, connect_ip) in cases {
              // Port 0 lets the OS pick a free port, so the test cannot fail because another
              // process (or a parallel test run) already holds a hard-coded port.
              let listener = UdpSocket::bind(SocketAddr::new(bind_ip, 0)).await.unwrap();
              let port = listener.local_addr().unwrap().port();
              let target = SocketAddr::new(connect_ip, port);
              let handle = tokio::spawn(async move {
                  let s = udp_connect(target).await.unwrap();
                  s.send(hello.as_bytes()).await.unwrap();
                  let mut buf = [0u8; 16];
                  let n = s.recv(&mut buf).await.unwrap();
                  assert_eq!(&buf[..n], hello.as_bytes());
              });
              let mut buf = [0u8; 16];
              let (n, addr) = listener.recv_from(&mut buf).await.unwrap();
              assert_eq!(&buf[..n], hello.as_bytes());
              listener.send_to(&buf[..n], addr).await.unwrap();
              handle.await.unwrap();
          }
      }
  }