helper.rs 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. use std::{
  2. collections::hash_map::DefaultHasher,
  3. hash::{Hash, Hasher},
  4. net::SocketAddr,
  5. time::Duration,
  6. };
  7. use anyhow::{anyhow, Result};
  8. use socket2::{SockRef, TcpKeepalive};
  9. use tokio::net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket};
  10. use tracing::trace;
  11. // Tokio hesitates to expose this option...So we have to do it on our own :(
  12. // The good news is that using socket2 it can be easily done, without losing portability.
  13. // See https://github.com/tokio-rs/tokio/issues/3082
  14. pub fn try_set_tcp_keepalive(
  15. conn: &TcpStream,
  16. keepalive_duration: Duration,
  17. keepalive_interval: Duration,
  18. ) -> Result<()> {
  19. let s = SockRef::from(conn);
  20. let keepalive = TcpKeepalive::new()
  21. .with_time(keepalive_duration)
  22. .with_interval(keepalive_interval);
  23. trace!(
  24. "Set TCP keepalive {:?} {:?}",
  25. keepalive_duration,
  26. keepalive_interval
  27. );
  28. Ok(s.set_tcp_keepalive(&keepalive)?)
  29. }
  30. #[allow(dead_code)]
  31. pub fn feature_not_compile(feature: &str) -> ! {
  32. panic!(
  33. "The feature '{}' is not compiled in this binary. Please re-compile rathole",
  34. feature
  35. )
  36. }
  37. /// Create a UDP socket and connect to `addr`
  38. pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
  39. let addr = lookup_host(addr)
  40. .await?
  41. .next()
  42. .ok_or(anyhow!("Failed to lookup the host"))?;
  43. let bind_addr = match addr {
  44. SocketAddr::V4(_) => "0.0.0.0:0",
  45. SocketAddr::V6(_) => ":::0",
  46. };
  47. let s = UdpSocket::bind(bind_addr).await?;
  48. s.connect(addr).await?;
  49. Ok(s)
  50. }
// FIXME: The following functions are intended for UDP load balancing, but are not used yet.
  52. #[allow(dead_code)]
  53. pub fn hash_socket_addr(a: &SocketAddr) -> u64 {
  54. let mut hasher = DefaultHasher::new();
  55. a.hash(&mut hasher);
  56. hasher.finish()
  57. }
  58. // Wait for the stabilization of https://doc.rust-lang.org/std/primitive.i64.html#method.log2
  59. #[allow(dead_code)]
  60. fn log2_floor(x: usize) -> u8 {
  61. (x as f64).log2().floor() as u8
  62. }
  63. #[allow(dead_code)]
  64. pub fn floor_to_pow_of_2(x: usize) -> usize {
  65. if x == 1 {
  66. return 1;
  67. }
  68. let w = log2_floor(x);
  69. 1 << w
  70. }
  71. #[cfg(test)]
  72. mod test {
  73. use crate::helper::{floor_to_pow_of_2, log2_floor};
  74. #[test]
  75. fn test_log2_floor() {
  76. let t = [
  77. (2, 1),
  78. (3, 1),
  79. (4, 2),
  80. (8, 3),
  81. (9, 3),
  82. (15, 3),
  83. (16, 4),
  84. (1023, 9),
  85. (1024, 10),
  86. (2000, 10),
  87. (2048, 11),
  88. ];
  89. for t in t {
  90. assert_eq!(log2_floor(t.0), t.1);
  91. }
  92. }
  93. #[test]
  94. fn test_floor_to_pow_of_2() {
  95. let t = [
  96. (1 as usize, 1 as usize),
  97. (2, 2),
  98. (3, 2),
  99. (4, 4),
  100. (5, 4),
  101. (15, 8),
  102. (31, 16),
  103. (33, 32),
  104. (1000, 512),
  105. (1500, 1024),
  106. (2300, 2048),
  107. ];
  108. for t in t {
  109. assert_eq!(floor_to_pow_of_2(t.0), t.1);
  110. }
  111. }
  112. }