integration_test.rs 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. use anyhow::{anyhow, Ok, Result};
  2. use common::{run_rathole_client, PING, PONG};
  3. use rand::Rng;
  4. use rand::RngCore;
  5. use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
  6. use std::time::Duration;
  7. use tokio::{
  8. io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
  9. net::{TcpStream, UdpSocket},
  10. sync::broadcast,
  11. time,
  12. };
  13. use tracing::{debug, info, instrument};
  14. use tracing_subscriber::EnvFilter;
  15. use crate::common::run_rathole_server;
  16. mod common;
  /// Address the local echo service listens on (passed to `echo_server`).
  const ECHO_SERVER_ADDR: &str = "127.0.0.1:8080";
  /// Address the local pingpong service listens on (passed to `pingpong_server`).
  const PINGPONG_SERVER_ADDR: &str = "127.0.0.1:8081";
  /// Address the echo hitters connect to; presumably the port the rathole server
  /// exposes for the echo service in the TOML configs — confirm there.
  const ECHO_SERVER_ADDR_EXPOSED: &str = "127.0.0.1:2334";
  /// Address the pingpong hitters connect to; presumably the port the rathole
  /// server exposes for the pingpong service in the TOML configs — confirm there.
  const PINGPONG_SERVER_ADDR_EXPOSED: &str = "127.0.0.1:2335";
  /// Total number of concurrent hitters in the heavy-load phase of `test`
  /// (half echo, half pingpong).
  const HITTER_NUM: usize = 4;
  /// The 12-byte signature that starts every PROXY protocol v2 header:
  /// `\r\n\r\n\0\r\nQUIT\n`.
  const PP2_SIG: [u8; 12] = *b"\r\n\r\n\0\r\nQUIT\n";
  /// Transport flavour of the services under test; selects between the TCP
  /// and UDP hitter implementations.
  #[derive(Debug, Clone, Copy)]
  enum Type {
      /// Exercise the exposed services over TCP streams.
      Tcp,
      /// Exercise the exposed services with UDP datagrams.
      Udp,
  }
  30. fn init() {
  31. let level = "info";
  32. let _ = tracing_subscriber::fmt()
  33. .with_env_filter(
  34. EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::from(level)),
  35. )
  36. .try_init();
  37. }
  38. #[tokio::test]
  39. async fn tcp() -> Result<()> {
  40. init();
  41. // Spawn a echo server
  42. tokio::spawn(async move {
  43. if let Err(e) = common::tcp::echo_server(ECHO_SERVER_ADDR).await {
  44. panic!("Failed to run the echo server for testing: {:?}", e);
  45. }
  46. });
  47. // Spawn a pingpong server
  48. tokio::spawn(async move {
  49. if let Err(e) = common::tcp::pingpong_server(PINGPONG_SERVER_ADDR).await {
  50. panic!("Failed to run the pingpong server for testing: {:?}", e);
  51. }
  52. });
  53. test("tests/for_tcp/tcp_transport.toml", Type::Tcp).await?;
  54. test_proxy_protocol("tests/for_tcp/tcp_transport_proxy_protocol_v1.toml").await?;
  55. test_proxy_protocol("tests/for_tcp/tcp_transport_proxy_protocol_v2.toml").await?;
  56. #[cfg(any(
  57. // FIXME: Self-signed certificate on macOS nativetls requires manual interference.
  58. all(target_os = "macos", feature = "rustls"),
  59. // On other OS accept run with either
  60. all(not(target_os = "macos"), any(feature = "native-tls", feature = "rustls")),
  61. ))]
  62. test("tests/for_tcp/tls_transport.toml", Type::Tcp).await?;
  63. #[cfg(feature = "noise")]
  64. test("tests/for_tcp/noise_transport.toml", Type::Tcp).await?;
  65. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  66. test("tests/for_tcp/websocket_transport.toml", Type::Tcp).await?;
  67. #[cfg(not(target_os = "macos"))]
  68. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  69. test("tests/for_tcp/websocket_tls_transport.toml", Type::Tcp).await?;
  70. Ok(())
  71. }
  72. #[tokio::test]
  73. async fn udp() -> Result<()> {
  74. init();
  75. // Spawn a echo server
  76. tokio::spawn(async move {
  77. if let Err(e) = common::udp::echo_server(ECHO_SERVER_ADDR).await {
  78. panic!("Failed to run the echo server for testing: {:?}", e);
  79. }
  80. });
  81. // Spawn a pingpong server
  82. tokio::spawn(async move {
  83. if let Err(e) = common::udp::pingpong_server(PINGPONG_SERVER_ADDR).await {
  84. panic!("Failed to run the pingpong server for testing: {:?}", e);
  85. }
  86. });
  87. test("tests/for_udp/tcp_transport.toml", Type::Udp).await?;
  88. #[cfg(any(
  89. // FIXME: Self-signed certificate on macOS nativetls requires manual interference.
  90. all(target_os = "macos", feature = "rustls"),
  91. // On other OS accept run with either
  92. all(not(target_os = "macos"), any(feature = "native-tls", feature = "rustls")),
  93. ))]
  94. test("tests/for_udp/tls_transport.toml", Type::Udp).await?;
  95. #[cfg(feature = "noise")]
  96. test("tests/for_udp/noise_transport.toml", Type::Udp).await?;
  97. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  98. test("tests/for_udp/websocket_transport.toml", Type::Udp).await?;
  99. #[cfg(not(target_os = "macos"))]
  100. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  101. test("tests/for_udp/websocket_tls_transport.toml", Type::Udp).await?;
  102. Ok(())
  103. }
  104. #[instrument]
  105. async fn test(config_path: &'static str, t: Type) -> Result<()> {
  106. if cfg!(not(all(feature = "client", feature = "server"))) {
  107. // Skip the test if the client or the server is not enabled
  108. return Ok(());
  109. }
  110. let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1);
  111. let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1);
  112. // Start the client
  113. info!("start the client");
  114. let client = tokio::spawn(async move {
  115. run_rathole_client(config_path, client_shutdown_rx)
  116. .await
  117. .unwrap();
  118. });
  119. // Sleep for 1 second. Expect the client keep retrying to reach the server
  120. time::sleep(Duration::from_secs(1)).await;
  121. // Start the server
  122. info!("start the server");
  123. let server = tokio::spawn(async move {
  124. run_rathole_server(config_path, server_shutdown_rx)
  125. .await
  126. .unwrap();
  127. });
  128. time::sleep(Duration::from_millis(2500)).await; // Wait for the client to retry
  129. info!("echo");
  130. echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap();
  131. info!("pingpong");
  132. pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t)
  133. .await
  134. .unwrap();
  135. // Simulate the client crash and restart
  136. info!("shutdown the client");
  137. client_shutdown_tx.send(true)?;
  138. let _ = tokio::join!(client);
  139. info!("restart the client");
  140. let client_shutdown_rx = client_shutdown_tx.subscribe();
  141. let client = tokio::spawn(async move {
  142. run_rathole_client(config_path, client_shutdown_rx)
  143. .await
  144. .unwrap();
  145. });
  146. time::sleep(Duration::from_secs(1)).await; // Wait for the client to start
  147. info!("echo");
  148. echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap();
  149. info!("pingpong");
  150. pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t)
  151. .await
  152. .unwrap();
  153. // Simulate the server crash and restart
  154. info!("shutdown the server");
  155. server_shutdown_tx.send(true)?;
  156. let _ = tokio::join!(server);
  157. info!("restart the server");
  158. let server_shutdown_rx = server_shutdown_tx.subscribe();
  159. let server = tokio::spawn(async move {
  160. run_rathole_server(config_path, server_shutdown_rx)
  161. .await
  162. .unwrap();
  163. });
  164. time::sleep(Duration::from_millis(2500)).await; // Wait for the client to retry
  165. // Simulate heavy load
  166. info!("lots of echo and pingpong");
  167. let mut v = Vec::new();
  168. for _ in 0..HITTER_NUM / 2 {
  169. v.push(tokio::spawn(async move {
  170. echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap();
  171. }));
  172. v.push(tokio::spawn(async move {
  173. pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t)
  174. .await
  175. .unwrap();
  176. }));
  177. }
  178. for h in v {
  179. assert!(tokio::join!(h).0.is_ok());
  180. }
  181. // Shutdown
  182. info!("shutdown the server and the client");
  183. server_shutdown_tx.send(true)?;
  184. client_shutdown_tx.send(true)?;
  185. let _ = tokio::join!(server, client);
  186. Ok(())
  187. }
  188. async fn echo_hitter(addr: &'static str, t: Type) -> Result<()> {
  189. match t {
  190. Type::Tcp => tcp_echo_hitter(addr).await,
  191. Type::Udp => udp_echo_hitter(addr).await,
  192. }
  193. }
  194. async fn pingpong_hitter(addr: &'static str, t: Type) -> Result<()> {
  195. match t {
  196. Type::Tcp => tcp_pingpong_hitter(addr).await,
  197. Type::Udp => udp_pingpong_hitter(addr).await,
  198. }
  199. }
  200. async fn tcp_echo_hitter(addr: &'static str) -> Result<()> {
  201. let mut conn = TcpStream::connect(addr).await?;
  202. let mut wr = [0u8; 1024];
  203. let mut rd = [0u8; 1024];
  204. for _ in 0..100 {
  205. rand::thread_rng().fill(&mut wr);
  206. conn.write_all(&wr).await?;
  207. conn.read_exact(&mut rd).await?;
  208. assert_eq!(wr, rd);
  209. }
  210. Ok(())
  211. }
  212. async fn udp_echo_hitter(addr: &'static str) -> Result<()> {
  213. let conn = UdpSocket::bind("127.0.0.1:0").await?;
  214. conn.connect(addr).await?;
  215. let mut wr = [0u8; 128];
  216. let mut rd = [0u8; 128];
  217. for _ in 0..3 {
  218. rand::thread_rng().fill(&mut wr);
  219. conn.send(&wr).await?;
  220. debug!("send");
  221. conn.recv(&mut rd).await?;
  222. debug!("recv");
  223. assert_eq!(wr, rd);
  224. }
  225. Ok(())
  226. }
  227. async fn tcp_pingpong_hitter(addr: &'static str) -> Result<()> {
  228. let mut conn = TcpStream::connect(addr).await?;
  229. let wr = PING.as_bytes();
  230. let mut rd = [0u8; PONG.len()];
  231. for _ in 0..100 {
  232. conn.write_all(wr).await?;
  233. conn.read_exact(&mut rd).await?;
  234. assert_eq!(rd, PONG.as_bytes());
  235. }
  236. Ok(())
  237. }
  238. async fn udp_pingpong_hitter(addr: &'static str) -> Result<()> {
  239. let conn = UdpSocket::bind("127.0.0.1:0").await?;
  240. conn.connect(&addr).await?;
  241. let wr = PING.as_bytes();
  242. let mut rd = [0u8; PONG.len()];
  243. for _ in 0..3 {
  244. conn.send(wr).await?;
  245. debug!("ping");
  246. conn.recv(&mut rd).await?;
  247. debug!("pong");
  248. assert_eq!(rd, PONG.as_bytes());
  249. }
  250. Ok(())
  251. }
  252. #[instrument]
  253. async fn test_proxy_protocol(config_path: &'static str) -> Result<()> {
  254. if cfg!(not(all(feature = "client", feature = "server"))) {
  255. return Ok(());
  256. }
  257. let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1);
  258. let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1);
  259. info!("start the client");
  260. let client = tokio::spawn(async move {
  261. run_rathole_client(config_path, client_shutdown_rx)
  262. .await
  263. .unwrap();
  264. });
  265. time::sleep(Duration::from_secs(1)).await;
  266. info!("start the server");
  267. let server = tokio::spawn(async move {
  268. run_rathole_server(config_path, server_shutdown_rx)
  269. .await
  270. .unwrap();
  271. });
  272. time::sleep(Duration::from_millis(2500)).await;
  273. info!("echo");
  274. tcp_echo_hitter_expect_proxy_protocol(ECHO_SERVER_ADDR_EXPOSED).await?;
  275. info!("pingpong )");
  276. tcp_pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED).await?;
  277. info!("shutdown the server and the client");
  278. server_shutdown_tx.send(true)?;
  279. client_shutdown_tx.send(true)?;
  280. let _ = tokio::join!(server, client);
  281. Ok(())
  282. }
  283. async fn read_proxy_protocol_header(rd: &mut BufReader<tokio::net::tcp::OwnedReadHalf>) -> Result<Vec<u8>> {
  284. // Read 12 bytes to distinguish v2 signature vs v1 ("PROXY ...")
  285. let mut first12 = [0u8; 12];
  286. time::timeout(Duration::from_secs(5), rd.read_exact(&mut first12)).await??;
  287. if first12 == PP2_SIG {
  288. // v2: read fixed header (ver/cmd, fam/proto, len[2]) then read len bytes
  289. let mut fixed = [0u8; 4];
  290. time::timeout(Duration::from_secs(5), rd.read_exact(&mut fixed)).await??;
  291. let len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize;
  292. let mut addr_and_tlvs = vec![0u8; len];
  293. time::timeout(Duration::from_secs(5), rd.read_exact(&mut addr_and_tlvs)).await??;
  294. let mut out = Vec::with_capacity(16 + len);
  295. out.extend_from_slice(&first12);
  296. out.extend_from_slice(&fixed);
  297. out.extend_from_slice(&addr_and_tlvs);
  298. Ok(out)
  299. } else {
  300. // v1: we've already consumed 12 bytes; read until newline to complete the line
  301. let mut out = first12.to_vec();
  302. let n = time::timeout(Duration::from_secs(5), rd.read_until(b'\n', &mut out)).await??;
  303. if n == 0 {
  304. return Err(anyhow!("EOF while reading proxy protocol v1 line"));
  305. }
  306. Ok(out)
  307. }
  308. }
  309. fn assert_proxy_v2_matches(header: &[u8], local: SocketAddr, peer: SocketAddr) {
  310. assert!(header.len() >= 16);
  311. assert_eq!(&header[..12], &PP2_SIG);
  312. // version/command
  313. assert_eq!(header[12], 0x21, "expected v2 PROXY command (0x21)");
  314. let fam_proto = header[13];
  315. let len = u16::from_be_bytes([header[14], header[15]]) as usize;
  316. assert_eq!(header.len(), 16 + len, "v2 length mismatch");
  317. match fam_proto {
  318. 0x11 => {
  319. // INET + STREAM, minimum 12 bytes address block
  320. assert!(len >= 12);
  321. let src = IpAddr::V4(Ipv4Addr::new(header[16], header[17], header[18], header[19]));
  322. let dst = IpAddr::V4(Ipv4Addr::new(header[20], header[21], header[22], header[23]));
  323. let src_port = u16::from_be_bytes([header[24], header[25]]);
  324. let dst_port = u16::from_be_bytes([header[26], header[27]]);
  325. assert_eq!(src, local.ip());
  326. assert_eq!(dst, peer.ip());
  327. assert_eq!(src_port, local.port());
  328. assert_eq!(dst_port, peer.port());
  329. }
  330. 0x21 => {
  331. // INET6 + STREAM, minimum 36 bytes address block
  332. assert!(len >= 36);
  333. let mut src_oct = [0u8; 16];
  334. let mut dst_oct = [0u8; 16];
  335. src_oct.copy_from_slice(&header[16..32]);
  336. dst_oct.copy_from_slice(&header[32..48]);
  337. let src = IpAddr::V6(Ipv6Addr::from(src_oct));
  338. let dst = IpAddr::V6(Ipv6Addr::from(dst_oct));
  339. let src_port = u16::from_be_bytes([header[48], header[49]]);
  340. let dst_port = u16::from_be_bytes([header[50], header[51]]);
  341. assert_eq!(src, local.ip());
  342. assert_eq!(dst, peer.ip());
  343. assert_eq!(src_port, local.port());
  344. assert_eq!(dst_port, peer.port());
  345. }
  346. other => panic!("unexpected v2 fam/proto byte: {other:#x}"),
  347. }
  348. }
  349. async fn tcp_echo_hitter_expect_proxy_protocol(addr: &'static str) -> Result<()> {
  350. let conn = TcpStream::connect(addr).await?;
  351. let local = conn.local_addr()?;
  352. let peer = conn.peer_addr()?;
  353. let (rd, mut wr) = conn.into_split();
  354. let mut rd = BufReader::new(rd);
  355. // Read & validate proxy protocol header (v1 or v2)
  356. let header = read_proxy_protocol_header(&mut rd).await?;
  357. if header.starts_with(b"PROXY ") {
  358. // v1 assertion (stringy)
  359. let proto = if local.is_ipv4() { "TCP4" } else { "TCP6" };
  360. let expected = format!(
  361. "PROXY {proto} {} {} {} {}\r\n",
  362. local.ip(),
  363. peer.ip(),
  364. local.port(),
  365. peer.port()
  366. )
  367. .into_bytes();
  368. assert_eq!(header, expected);
  369. } else {
  370. // v2 assertion (binary)
  371. assert_proxy_v2_matches(&header, local, peer);
  372. }
  373. // Now the stream should behave like a normal echo connection.
  374. let mut wr_buf = [0u8; 1024];
  375. let mut rd_buf = [0u8; 1024];
  376. for _ in 0..100 {
  377. rand::thread_rng().fill_bytes(&mut wr_buf);
  378. wr.write_all(&wr_buf).await?;
  379. rd.read_exact(&mut rd_buf).await?;
  380. assert_eq!(wr_buf, rd_buf);
  381. }
  382. Ok(())
  383. }