helper.rs 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. use anyhow::{anyhow, Result};
  2. use async_http_proxy::{http_connect_tokio, http_connect_tokio_with_basic_auth};
  3. use backoff::{backoff::Backoff, Notify};
  4. use socket2::{SockRef, TcpKeepalive};
  5. use std::{future::Future, net::SocketAddr, time::Duration};
  6. use tokio::{
  7. net::{lookup_host, TcpStream, ToSocketAddrs, UdpSocket},
  8. sync::broadcast,
  9. };
  10. use tracing::trace;
  11. use url::Url;
  12. use crate::transport::AddrMaybeCached;
  13. // Tokio hesitates to expose this option...So we have to do it on our own :(
  14. // The good news is that using socket2 it can be easily done, without losing portability.
  15. // See https://github.com/tokio-rs/tokio/issues/3082
  16. pub fn try_set_tcp_keepalive(
  17. conn: &TcpStream,
  18. keepalive_duration: Duration,
  19. keepalive_interval: Duration,
  20. ) -> Result<()> {
  21. let s = SockRef::from(conn);
  22. let keepalive = TcpKeepalive::new()
  23. .with_time(keepalive_duration)
  24. .with_interval(keepalive_interval);
  25. trace!(
  26. "Set TCP keepalive {:?} {:?}",
  27. keepalive_duration,
  28. keepalive_interval
  29. );
  30. Ok(s.set_tcp_keepalive(&keepalive)?)
  31. }
  32. #[allow(dead_code)]
  33. pub fn feature_not_compile(feature: &str) -> ! {
  34. panic!(
  35. "The feature '{}' is not compiled in this binary. Please re-compile rathole",
  36. feature
  37. )
  38. }
  39. pub async fn to_socket_addr<A: ToSocketAddrs>(addr: A) -> Result<SocketAddr> {
  40. lookup_host(addr)
  41. .await?
  42. .next()
  43. .ok_or_else(|| anyhow!("Failed to lookup the host"))
  44. }
  45. pub fn host_port_pair(s: &str) -> Result<(&str, u16)> {
  46. let semi = s.rfind(':').expect("missing semicolon");
  47. Ok((&s[..semi], s[semi + 1..].parse()?))
  48. }
  49. /// Create a UDP socket and connect to `addr`
  50. pub async fn udp_connect<A: ToSocketAddrs>(addr: A) -> Result<UdpSocket> {
  51. let addr = to_socket_addr(addr).await?;
  52. let bind_addr = match addr {
  53. SocketAddr::V4(_) => "0.0.0.0:0",
  54. SocketAddr::V6(_) => ":::0",
  55. };
  56. let s = UdpSocket::bind(bind_addr).await?;
  57. s.connect(addr).await?;
  58. Ok(s)
  59. }
  60. /// Create a TcpStream using a proxy
  61. /// e.g. socks5://user:pass@127.0.0.1:1080 http://127.0.0.1:8080
  62. pub async fn tcp_connect_with_proxy(
  63. addr: &AddrMaybeCached,
  64. proxy: Option<&Url>,
  65. ) -> Result<TcpStream> {
  66. if let Some(url) = proxy {
  67. let addr = &addr.addr;
  68. let mut s = TcpStream::connect((
  69. url.host_str().expect("proxy url should have host field"),
  70. url.port().expect("proxy url should have port field"),
  71. ))
  72. .await?;
  73. let auth = if !url.username().is_empty() || url.password().is_some() {
  74. Some(async_socks5::Auth {
  75. username: url.username().into(),
  76. password: url.password().unwrap_or("").into(),
  77. })
  78. } else {
  79. None
  80. };
  81. match url.scheme() {
  82. "socks5" => {
  83. async_socks5::connect(&mut s, host_port_pair(addr)?, auth).await?;
  84. }
  85. "http" => {
  86. let (host, port) = host_port_pair(addr)?;
  87. match auth {
  88. Some(auth) => {
  89. http_connect_tokio_with_basic_auth(
  90. &mut s,
  91. host,
  92. port,
  93. &auth.username,
  94. &auth.password,
  95. )
  96. .await?
  97. }
  98. None => http_connect_tokio(&mut s, host, port).await?,
  99. }
  100. }
  101. _ => panic!("unknown proxy scheme"),
  102. }
  103. Ok(s)
  104. } else {
  105. Ok(match addr.socket_addr {
  106. Some(s) => TcpStream::connect(s).await?,
  107. None => TcpStream::connect(&addr.addr).await?,
  108. })
  109. }
  110. }
  111. // Wrapper of retry_notify
  112. pub async fn retry_notify_with_deadline<I, E, Fn, Fut, B, N>(
  113. backoff: B,
  114. operation: Fn,
  115. notify: N,
  116. deadline: &mut broadcast::Receiver<bool>,
  117. ) -> Result<I>
  118. where
  119. E: std::error::Error + Send + Sync + 'static,
  120. B: Backoff,
  121. Fn: FnMut() -> Fut,
  122. Fut: Future<Output = std::result::Result<I, backoff::Error<E>>>,
  123. N: Notify<E>,
  124. {
  125. tokio::select! {
  126. v = backoff::future::retry_notify(backoff, operation, notify) => {
  127. v.map_err(anyhow::Error::new)
  128. }
  129. _ = deadline.recv() => {
  130. Err(anyhow!("shutdown"))
  131. }
  132. }
  133. }