Răsfoiți Sursa

Add unit tests

Stephen Tierney 2 săptămâni în urmă
părinte
comite
3c06f795eb
1 a modificat fișierele cu 119 adăugiri și 0 ștergeri
  1. 119 0
      src/helper.rs

+ 119 - 0
src/helper.rs

@@ -208,3 +208,122 @@ pub fn generate_proxy_protocol_v1_header(s: &TcpStream) -> Result<String> {
     );
     Ok(header)
 }
+
+#[cfg(test)]
+mod proxy_protocol_tests {
+    use super::generate_proxy_protocol_v1_header;
+    use tokio::io::{AsyncReadExt, AsyncWriteExt};
+    use tokio::net::{TcpListener, TcpStream};
+
+    fn expected_v1_header(local: std::net::SocketAddr, remote: std::net::SocketAddr) -> String {
+        let proto = if local.is_ipv4() { "TCP4" } else { "TCP6" };
+        format!(
+            "PROXY {proto} {} {} {} {}\r\n",
+            remote.ip(),
+            local.ip(),
+            remote.port(),
+            local.port()
+        )
+    }
+
+    #[tokio::test]
+    async fn v1_header_ipv4_format_is_correct() {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        // Create a connection so we get a real TcpStream with real peer/local addrs
+        let _client = TcpStream::connect(addr).await.unwrap();
+        let (server, _) = listener.accept().await.unwrap();
+
+        let local = server.local_addr().unwrap();
+        let remote = server.peer_addr().unwrap();
+
+        assert!(local.is_ipv4());
+        assert!(remote.is_ipv4());
+
+        let expected = expected_v1_header(local, remote);
+        let got = generate_proxy_protocol_v1_header(&server).unwrap();
+
+        assert_eq!(got, expected);
+        assert!(got.ends_with("\r\n"));
+        assert!(got.starts_with("PROXY TCP4 "));
+    }
+
+    #[tokio::test]
+    async fn v1_header_ipv6_format_is_correct_or_skipped_if_unavailable() {
+        // Some CI environments don’t have IPv6 loopback enabled; skip gracefully.
+        let listener = match TcpListener::bind("[::1]:0").await {
+            Ok(l) => l,
+            Err(_) => return,
+        };
+        let addr = listener.local_addr().unwrap();
+
+        let _client = match TcpStream::connect(addr).await {
+            Ok(c) => c,
+            Err(_) => return,
+        };
+        let (server, _) = listener.accept().await.unwrap();
+
+        let local = server.local_addr().unwrap();
+        let remote = server.peer_addr().unwrap();
+
+        assert!(local.is_ipv6());
+        assert!(remote.is_ipv6());
+
+        let expected = expected_v1_header(local, remote);
+        let got = generate_proxy_protocol_v1_header(&server).unwrap();
+
+        assert_eq!(got, expected);
+        assert!(got.ends_with("\r\n"));
+        assert!(got.starts_with("PROXY TCP6 "));
+    }
+
+    #[tokio::test]
+    async fn header_bytes_are_sent_before_payload_when_written_then_forwarded() {
+        // This simulates the exact ordering your server code implements:
+        // write PROXY header -> flush -> start forwarding bytes. :contentReference[oaicite:1]{index=1}
+
+        // Visitor side connection (incoming connection to the server)
+        let visitor_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let visitor_addr = visitor_listener.local_addr().unwrap();
+
+        let mut visitor_client = TcpStream::connect(visitor_addr).await.unwrap();
+        let (mut visitor_server, _) = visitor_listener.accept().await.unwrap();
+
+        let local = visitor_server.local_addr().unwrap();
+        let remote = visitor_server.peer_addr().unwrap();
+        let expected_header = expected_v1_header(local, remote);
+
+        let payload = b"hello proxy protocol";
+
+        // Simulate the “data channel” stream with a duplex pipe.
+        // (One end is what the server writes into; the other end is what the downstream reads.)
+        let (mut ch, mut downstream) = tokio::io::duplex(4096);
+
+        let server_task = tokio::spawn(async move {
+            let header = generate_proxy_protocol_v1_header(&visitor_server).unwrap();
+            ch.write_all(header.as_bytes()).await.unwrap();
+            ch.flush().await.unwrap();
+
+            tokio::io::copy_bidirectional(&mut visitor_server, &mut ch)
+                .await
+                .unwrap();
+        });
+
+        // Visitor sends payload
+        visitor_client.write_all(payload).await.unwrap();
+        visitor_client.shutdown().await.unwrap();
+
+        // Downstream should see header first, then payload
+        let mut buf = vec![0u8; expected_header.as_bytes().len() + payload.len()];
+        downstream.read_exact(&mut buf).await.unwrap();
+
+        let header_len = expected_header.as_bytes().len();
+        assert_eq!(&buf[..header_len], expected_header.as_bytes());
+        assert_eq!(&buf[header_len..], payload);
+
+        // Close downstream to let copy_bidirectional finish cleanly
+        drop(downstream);
+        server_task.await.unwrap();
+    }
+}