helper.rs 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. use anyhow::{anyhow, Result};
  2. use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth};
  3. use backoff::{backoff::Backoff, Notify};
  4. use socket2::{SockRef, TcpKeepalive};
  5. use std::{future::Future, net::SocketAddr, time::Duration};
  6. use tokio::{
  7. net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket},
  8. sync::broadcast,
  9. };
  10. use tracing::trace;
  11. use url::Url;
  12. // Tokio hesitates to expose this option...So we have to do it on our own :(
  13. // The good news is that using socket2 it can be easily done, without losing portability.
  14. // See https://github.com/tokio-rs/tokio/issues/3082
  15. pub fn try_set_tcp_keepalive(
  16. conn: &TcpStream,
  17. keepalive_duration: Duration,
  18. keepalive_interval: Duration,
  19. ) -> Result<()> {
  20. let s = SockRef::from(conn);
  21. let keepalive = TcpKeepalive::new()
  22. .with_time(keepalive_duration)
  23. .with_interval(keepalive_interval);
  24. trace!(
  25. "Set TCP keepalive {:?} {:?}",
  26. keepalive_duration,
  27. keepalive_interval
  28. );
  29. Ok(s.set_tcp_keepalive(&keepalive)?)
  30. }
  /// Panic because `feature` was not enabled when this binary was compiled.
  ///
  /// # Panics
  ///
  /// Always. The message names `feature` and asks the user to re-compile rathole.
  // NOTE(review): `dead_code` is allowed presumably because callers are cfg-gated — confirm.
  #[allow(dead_code)]
  pub fn feature_not_compile(feature: &str) -> ! {
      let message = format!(
          "The feature '{}' is not compiled in this binary. Please re-compile rathole",
          feature
      );
      panic!("{}", message)
  }
  38. async fn to_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
  39. lookup_host(addr)
  40. .await?
  41. .next()
  42. .ok_or(anyhow!("Failed to lookup the host"))
  43. }
  44. pub fn host_port_pair(s: &str) -> Result<(&str, u16)> {
  45. let semi = s.rfind(':').expect("missing semicolon");
  46. Ok((&s[..semi], s[semi + 1..].parse()?))
  47. }
  48. /// Create a UDP socket and connect to `addr`
  49. pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
  50. let addr = to_socket_addr(addr).await?;
  51. let bind_addr = match addr {
  52. SocketAddr::V4(_) => "0.0.0.0:0",
  53. SocketAddr::V6(_) => ":::0",
  54. };
  55. let s = UdpSocket::bind(bind_addr).await?;
  56. s.connect(addr).await?;
  57. Ok(s)
  58. }
  59. /// Create a TcpStream using a proxy
  60. /// e.g. socks5://user:pass@127.0.0.1:1080 http://127.0.0.1:8080
  61. pub async fn tcp_connect_with_proxy(addr: &str, proxy: Option<&Url>) -> Result<TcpStream> {
  62. if let Some(url) = proxy {
  63. let mut s = TcpStream::connect((
  64. url.host_str().expect("proxy url should have host field"),
  65. url.port().expect("proxy url should have port field"),
  66. ))
  67. .await?;
  68. let auth = if !url.username().is_empty() || url.password().is_some() {
  69. Some(async_socks5::Auth {
  70. username: url.username().into(),
  71. password: url.password().unwrap_or("").into(),
  72. })
  73. } else {
  74. None
  75. };
  76. match url.scheme() {
  77. "socks5" => {
  78. async_socks5::connect(&mut s, host_port_pair(addr)?, auth).await?;
  79. }
  80. "http" => {
  81. let (host, port) = host_port_pair(addr)?;
  82. match auth {
  83. Some(auth) => {
  84. http_connect_tokio_with_basic_auth(
  85. &mut s,
  86. host,
  87. port,
  88. &auth.username,
  89. &auth.password,
  90. )
  91. .await?
  92. }
  93. None => http_connect_tokio(&mut s, host, port).await?,
  94. }
  95. }
  96. _ => panic!("unknown proxy scheme"),
  97. }
  98. Ok(s)
  99. } else {
  100. Ok(TcpStream::connect(addr).await?)
  101. }
  102. }
  103. // Wrapper of retry_notify
  104. pub async fn retry_notify_with_deadline<I, E, Fn, Fut, B, N>(
  105. backoff: B,
  106. operation: Fn,
  107. notify: N,
  108. deadline: &mut broadcast::Receiver<bool>,
  109. ) -> Result<I>
  110. where
  111. E: std::error::Error + Send + Sync + 'static,
  112. B: Backoff,
  113. Fn: FnMut() -> Fut,
  114. Fut: Future<Output = std::result::Result<I, backoff::Error<E>>>,
  115. N: Notify<E>,
  116. {
  117. tokio::select! {
  118. v = backoff::future::retry_notify(backoff, operation, notify) => {
  119. v.map_err(anyhow::Error::new)
  120. }
  121. _ = deadline.recv() => {
  122. Err(anyhow!("shutdown"))
  123. }
  124. }
  125. }