integration_test.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386
  1. use anyhow::{Ok, Result};
  2. use common::{run_rathole_client, PING, PONG};
  3. use rand::Rng;
  4. use std::time::Duration;
  5. use tokio::{
  6. io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
  7. net::{TcpStream, UdpSocket},
  8. sync::broadcast,
  9. time,
  10. };
  11. use tracing::{debug, info, instrument};
  12. use tracing_subscriber::EnvFilter;
  13. use crate::common::run_rathole_server;
  14. mod common;
  /// Local address the echo test server binds to (the service rathole forwards to).
  const ECHO_SERVER_ADDR: &str = "127.0.0.1:8080";
  /// Address the hitters dial to reach the echo service through the tunnel.
  /// NOTE(review): expected to match the exposed bind address in the test configs.
  const ECHO_SERVER_ADDR_EXPOSED: &str = "127.0.0.1:2334";

  /// Local address the pingpong test server binds to.
  const PINGPONG_SERVER_ADDR: &str = "127.0.0.1:8081";
  /// Address the hitters dial to reach the pingpong service through the tunnel.
  const PINGPONG_SERVER_ADDR_EXPOSED: &str = "127.0.0.1:2335";

  /// Total number of concurrent hitters in the heavy-load phase (half echo, half pingpong).
  const HITTER_NUM: usize = 4;
  /// Transport the hitters use when talking to the exposed services.
  #[derive(Debug, Clone, Copy)]
  enum Type {
      /// Exercise the services over TCP streams.
      Tcp,
      /// Exercise the services over UDP datagrams.
      Udp,
  }
  25. fn init() {
  26. let level = "info";
  27. let _ = tracing_subscriber::fmt()
  28. .with_env_filter(
  29. EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::from(level)),
  30. )
  31. .try_init();
  32. }
  33. #[tokio::test]
  34. async fn tcp() -> Result<()> {
  35. init();
  36. // Spawn a echo server
  37. tokio::spawn(async move {
  38. if let Err(e) = common::tcp::echo_server(ECHO_SERVER_ADDR).await {
  39. panic!("Failed to run the echo server for testing: {:?}", e);
  40. }
  41. });
  42. // Spawn a pingpong server
  43. tokio::spawn(async move {
  44. if let Err(e) = common::tcp::pingpong_server(PINGPONG_SERVER_ADDR).await {
  45. panic!("Failed to run the pingpong server for testing: {:?}", e);
  46. }
  47. });
  48. test("tests/for_tcp/tcp_transport.toml", Type::Tcp).await?;
  49. test_proxy_protocol("tests/for_tcp/tcp_transport_proxy_protocol.toml").await?;
  50. #[cfg(any(
  51. // FIXME: Self-signed certificate on macOS nativetls requires manual interference.
  52. all(target_os = "macos", feature = "rustls"),
  53. // On other OS accept run with either
  54. all(not(target_os = "macos"), any(feature = "native-tls", feature = "rustls")),
  55. ))]
  56. test("tests/for_tcp/tls_transport.toml", Type::Tcp).await?;
  57. #[cfg(feature = "noise")]
  58. test("tests/for_tcp/noise_transport.toml", Type::Tcp).await?;
  59. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  60. test("tests/for_tcp/websocket_transport.toml", Type::Tcp).await?;
  61. #[cfg(not(target_os = "macos"))]
  62. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  63. test("tests/for_tcp/websocket_tls_transport.toml", Type::Tcp).await?;
  64. Ok(())
  65. }
  66. #[tokio::test]
  67. async fn udp() -> Result<()> {
  68. init();
  69. // Spawn a echo server
  70. tokio::spawn(async move {
  71. if let Err(e) = common::udp::echo_server(ECHO_SERVER_ADDR).await {
  72. panic!("Failed to run the echo server for testing: {:?}", e);
  73. }
  74. });
  75. // Spawn a pingpong server
  76. tokio::spawn(async move {
  77. if let Err(e) = common::udp::pingpong_server(PINGPONG_SERVER_ADDR).await {
  78. panic!("Failed to run the pingpong server for testing: {:?}", e);
  79. }
  80. });
  81. test("tests/for_udp/tcp_transport.toml", Type::Udp).await?;
  82. #[cfg(any(
  83. // FIXME: Self-signed certificate on macOS nativetls requires manual interference.
  84. all(target_os = "macos", feature = "rustls"),
  85. // On other OS accept run with either
  86. all(not(target_os = "macos"), any(feature = "native-tls", feature = "rustls")),
  87. ))]
  88. test("tests/for_udp/tls_transport.toml", Type::Udp).await?;
  89. #[cfg(feature = "noise")]
  90. test("tests/for_udp/noise_transport.toml", Type::Udp).await?;
  91. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  92. test("tests/for_udp/websocket_transport.toml", Type::Udp).await?;
  93. #[cfg(not(target_os = "macos"))]
  94. #[cfg(any(feature = "websocket-native-tls", feature = "websocket-rustls"))]
  95. test("tests/for_udp/websocket_tls_transport.toml", Type::Udp).await?;
  96. Ok(())
  97. }
  98. #[instrument]
  99. async fn test(config_path: &'static str, t: Type) -> Result<()> {
  100. if cfg!(not(all(feature = "client", feature = "server"))) {
  101. // Skip the test if the client or the server is not enabled
  102. return Ok(());
  103. }
  104. let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1);
  105. let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1);
  106. // Start the client
  107. info!("start the client");
  108. let client = tokio::spawn(async move {
  109. run_rathole_client(config_path, client_shutdown_rx)
  110. .await
  111. .unwrap();
  112. });
  113. // Sleep for 1 second. Expect the client keep retrying to reach the server
  114. time::sleep(Duration::from_secs(1)).await;
  115. // Start the server
  116. info!("start the server");
  117. let server = tokio::spawn(async move {
  118. run_rathole_server(config_path, server_shutdown_rx)
  119. .await
  120. .unwrap();
  121. });
  122. time::sleep(Duration::from_millis(2500)).await; // Wait for the client to retry
  123. info!("echo");
  124. echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap();
  125. info!("pingpong");
  126. pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t)
  127. .await
  128. .unwrap();
  129. // Simulate the client crash and restart
  130. info!("shutdown the client");
  131. client_shutdown_tx.send(true)?;
  132. let _ = tokio::join!(client);
  133. info!("restart the client");
  134. let client_shutdown_rx = client_shutdown_tx.subscribe();
  135. let client = tokio::spawn(async move {
  136. run_rathole_client(config_path, client_shutdown_rx)
  137. .await
  138. .unwrap();
  139. });
  140. time::sleep(Duration::from_secs(1)).await; // Wait for the client to start
  141. info!("echo");
  142. echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap();
  143. info!("pingpong");
  144. pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t)
  145. .await
  146. .unwrap();
  147. // Simulate the server crash and restart
  148. info!("shutdown the server");
  149. server_shutdown_tx.send(true)?;
  150. let _ = tokio::join!(server);
  151. info!("restart the server");
  152. let server_shutdown_rx = server_shutdown_tx.subscribe();
  153. let server = tokio::spawn(async move {
  154. run_rathole_server(config_path, server_shutdown_rx)
  155. .await
  156. .unwrap();
  157. });
  158. time::sleep(Duration::from_millis(2500)).await; // Wait for the client to retry
  159. // Simulate heavy load
  160. info!("lots of echo and pingpong");
  161. let mut v = Vec::new();
  162. for _ in 0..HITTER_NUM / 2 {
  163. v.push(tokio::spawn(async move {
  164. echo_hitter(ECHO_SERVER_ADDR_EXPOSED, t).await.unwrap();
  165. }));
  166. v.push(tokio::spawn(async move {
  167. pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED, t)
  168. .await
  169. .unwrap();
  170. }));
  171. }
  172. for h in v {
  173. assert!(tokio::join!(h).0.is_ok());
  174. }
  175. // Shutdown
  176. info!("shutdown the server and the client");
  177. server_shutdown_tx.send(true)?;
  178. client_shutdown_tx.send(true)?;
  179. let _ = tokio::join!(server, client);
  180. Ok(())
  181. }
  182. async fn echo_hitter(addr: &'static str, t: Type) -> Result<()> {
  183. match t {
  184. Type::Tcp => tcp_echo_hitter(addr).await,
  185. Type::Udp => udp_echo_hitter(addr).await,
  186. }
  187. }
  188. async fn pingpong_hitter(addr: &'static str, t: Type) -> Result<()> {
  189. match t {
  190. Type::Tcp => tcp_pingpong_hitter(addr).await,
  191. Type::Udp => udp_pingpong_hitter(addr).await,
  192. }
  193. }
  194. async fn tcp_echo_hitter(addr: &'static str) -> Result<()> {
  195. let mut conn = TcpStream::connect(addr).await?;
  196. let mut wr = [0u8; 1024];
  197. let mut rd = [0u8; 1024];
  198. for _ in 0..100 {
  199. rand::thread_rng().fill(&mut wr);
  200. conn.write_all(&wr).await?;
  201. conn.read_exact(&mut rd).await?;
  202. assert_eq!(wr, rd);
  203. }
  204. Ok(())
  205. }
  206. async fn udp_echo_hitter(addr: &'static str) -> Result<()> {
  207. let conn = UdpSocket::bind("127.0.0.1:0").await?;
  208. conn.connect(addr).await?;
  209. let mut wr = [0u8; 128];
  210. let mut rd = [0u8; 128];
  211. for _ in 0..3 {
  212. rand::thread_rng().fill(&mut wr);
  213. conn.send(&wr).await?;
  214. debug!("send");
  215. conn.recv(&mut rd).await?;
  216. debug!("recv");
  217. assert_eq!(wr, rd);
  218. }
  219. Ok(())
  220. }
  221. async fn tcp_pingpong_hitter(addr: &'static str) -> Result<()> {
  222. let mut conn = TcpStream::connect(addr).await?;
  223. let wr = PING.as_bytes();
  224. let mut rd = [0u8; PONG.len()];
  225. for _ in 0..100 {
  226. conn.write_all(wr).await?;
  227. conn.read_exact(&mut rd).await?;
  228. assert_eq!(rd, PONG.as_bytes());
  229. }
  230. Ok(())
  231. }
  232. async fn udp_pingpong_hitter(addr: &'static str) -> Result<()> {
  233. let conn = UdpSocket::bind("127.0.0.1:0").await?;
  234. conn.connect(&addr).await?;
  235. let wr = PING.as_bytes();
  236. let mut rd = [0u8; PONG.len()];
  237. for _ in 0..3 {
  238. conn.send(wr).await?;
  239. debug!("ping");
  240. conn.recv(&mut rd).await?;
  241. debug!("pong");
  242. assert_eq!(rd, PONG.as_bytes());
  243. }
  244. Ok(())
  245. }
  246. #[instrument]
  247. async fn test_proxy_protocol(config_path: &'static str) -> Result<()> {
  248. if cfg!(not(all(feature = "client", feature = "server"))) {
  249. return Ok(());
  250. }
  251. let (client_shutdown_tx, client_shutdown_rx) = broadcast::channel(1);
  252. let (server_shutdown_tx, server_shutdown_rx) = broadcast::channel(1);
  253. info!("start the client");
  254. let client = tokio::spawn(async move {
  255. run_rathole_client(config_path, client_shutdown_rx)
  256. .await
  257. .unwrap();
  258. });
  259. time::sleep(Duration::from_secs(1)).await;
  260. info!("start the server");
  261. let server = tokio::spawn(async move {
  262. run_rathole_server(config_path, server_shutdown_rx)
  263. .await
  264. .unwrap();
  265. });
  266. time::sleep(Duration::from_millis(2500)).await;
  267. info!("echo");
  268. tcp_echo_hitter_expect_proxy_protocol(ECHO_SERVER_ADDR_EXPOSED).await?;
  269. info!("pingpong )");
  270. tcp_pingpong_hitter(PINGPONG_SERVER_ADDR_EXPOSED).await?;
  271. info!("shutdown the server and the client");
  272. server_shutdown_tx.send(true)?;
  273. client_shutdown_tx.send(true)?;
  274. let _ = tokio::join!(server, client);
  275. Ok(())
  276. }
  277. async fn tcp_echo_hitter_expect_proxy_protocol(addr: &'static str) -> Result<()> {
  278. let conn = TcpStream::connect(addr).await?;
  279. let local = conn.local_addr()?;
  280. let peer = conn.peer_addr()?;
  281. let (rd, mut wr) = conn.into_split();
  282. let mut rd = BufReader::new(rd);
  283. // Read the echoed PROXY header line first.
  284. let mut header = String::new();
  285. let n = time::timeout(Duration::from_secs(5), rd.read_line(&mut header)).await??;
  286. assert!(n > 0, "expected a proxy protocol header line");
  287. let expected = format!(
  288. "PROXY TCP4 {} {} {} {}\r\n",
  289. local.ip(),
  290. peer.ip(),
  291. local.port(),
  292. peer.port()
  293. );
  294. assert_eq!(header, expected);
  295. // Now the stream should behave like a normal echo connection.
  296. let mut wr_buf = [0u8; 1024];
  297. let mut rd_buf = [0u8; 1024];
  298. for _ in 0..100 {
  299. rand::thread_rng().fill(&mut wr_buf);
  300. wr.write_all(&wr_buf).await?;
  301. rd.read_exact(&mut rd_buf).await?;
  302. assert_eq!(wr_buf, rd_buf);
  303. }
  304. Ok(())
  305. }