|
@@ -208,3 +208,122 @@ pub fn generate_proxy_protocol_v1_header(s: &TcpStream) -> Result<String> {
|
|
|
);
|
|
);
|
|
|
Ok(header)
|
|
Ok(header)
|
|
|
}
|
|
}
|
|
|
|
|
+
|
|
|
|
|
+#[cfg(test)]
|
|
|
|
|
+mod proxy_protocol_tests {
|
|
|
|
|
+ use super::generate_proxy_protocol_v1_header;
|
|
|
|
|
+ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
|
+ use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
+
|
|
|
|
|
+ fn expected_v1_header(local: std::net::SocketAddr, remote: std::net::SocketAddr) -> String {
|
|
|
|
|
+ let proto = if local.is_ipv4() { "TCP4" } else { "TCP6" };
|
|
|
|
|
+ format!(
|
|
|
|
|
+ "PROXY {proto} {} {} {} {}\r\n",
|
|
|
|
|
+ remote.ip(),
|
|
|
|
|
+ local.ip(),
|
|
|
|
|
+ remote.port(),
|
|
|
|
|
+ local.port()
|
|
|
|
|
+ )
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[tokio::test]
|
|
|
|
|
+ async fn v1_header_ipv4_format_is_correct() {
|
|
|
|
|
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
|
|
|
+ let addr = listener.local_addr().unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ // Create a connection so we get a real TcpStream with real peer/local addrs
|
|
|
|
|
+ let _client = TcpStream::connect(addr).await.unwrap();
|
|
|
|
|
+ let (server, _) = listener.accept().await.unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let local = server.local_addr().unwrap();
|
|
|
|
|
+ let remote = server.peer_addr().unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ assert!(local.is_ipv4());
|
|
|
|
|
+ assert!(remote.is_ipv4());
|
|
|
|
|
+
|
|
|
|
|
+ let expected = expected_v1_header(local, remote);
|
|
|
|
|
+ let got = generate_proxy_protocol_v1_header(&server).unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ assert_eq!(got, expected);
|
|
|
|
|
+ assert!(got.ends_with("\r\n"));
|
|
|
|
|
+ assert!(got.starts_with("PROXY TCP4 "));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[tokio::test]
|
|
|
|
|
+ async fn v1_header_ipv6_format_is_correct_or_skipped_if_unavailable() {
|
|
|
|
|
+ // Some CI environments don’t have IPv6 loopback enabled; skip gracefully.
|
|
|
|
|
+ let listener = match TcpListener::bind("[::1]:0").await {
|
|
|
|
|
+ Ok(l) => l,
|
|
|
|
|
+ Err(_) => return,
|
|
|
|
|
+ };
|
|
|
|
|
+ let addr = listener.local_addr().unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let _client = match TcpStream::connect(addr).await {
|
|
|
|
|
+ Ok(c) => c,
|
|
|
|
|
+ Err(_) => return,
|
|
|
|
|
+ };
|
|
|
|
|
+ let (server, _) = listener.accept().await.unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let local = server.local_addr().unwrap();
|
|
|
|
|
+ let remote = server.peer_addr().unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ assert!(local.is_ipv6());
|
|
|
|
|
+ assert!(remote.is_ipv6());
|
|
|
|
|
+
|
|
|
|
|
+ let expected = expected_v1_header(local, remote);
|
|
|
|
|
+ let got = generate_proxy_protocol_v1_header(&server).unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ assert_eq!(got, expected);
|
|
|
|
|
+ assert!(got.ends_with("\r\n"));
|
|
|
|
|
+ assert!(got.starts_with("PROXY TCP6 "));
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ #[tokio::test]
|
|
|
|
|
+ async fn header_bytes_are_sent_before_payload_when_written_then_forwarded() {
|
|
|
|
|
+ // This simulates the exact ordering your server code implements:
|
|
|
|
|
+ // write PROXY header -> flush -> start forwarding bytes. :contentReference[oaicite:1]{index=1}
|
|
|
|
|
+
|
|
|
|
|
+ // Visitor side connection (incoming connection to the server)
|
|
|
|
|
+ let visitor_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
|
|
|
+ let visitor_addr = visitor_listener.local_addr().unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let mut visitor_client = TcpStream::connect(visitor_addr).await.unwrap();
|
|
|
|
|
+ let (mut visitor_server, _) = visitor_listener.accept().await.unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let local = visitor_server.local_addr().unwrap();
|
|
|
|
|
+ let remote = visitor_server.peer_addr().unwrap();
|
|
|
|
|
+ let expected_header = expected_v1_header(local, remote);
|
|
|
|
|
+
|
|
|
|
|
+ let payload = b"hello proxy protocol";
|
|
|
|
|
+
|
|
|
|
|
+ // Simulate the “data channel” stream with a duplex pipe.
|
|
|
|
|
+ // (One end is what the server writes into; the other end is what the downstream reads.)
|
|
|
|
|
+ let (mut ch, mut downstream) = tokio::io::duplex(4096);
|
|
|
|
|
+
|
|
|
|
|
+ let server_task = tokio::spawn(async move {
|
|
|
|
|
+ let header = generate_proxy_protocol_v1_header(&visitor_server).unwrap();
|
|
|
|
|
+ ch.write_all(header.as_bytes()).await.unwrap();
|
|
|
|
|
+ ch.flush().await.unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ tokio::io::copy_bidirectional(&mut visitor_server, &mut ch)
|
|
|
|
|
+ .await
|
|
|
|
|
+ .unwrap();
|
|
|
|
|
+ });
|
|
|
|
|
+
|
|
|
|
|
+ // Visitor sends payload
|
|
|
|
|
+ visitor_client.write_all(payload).await.unwrap();
|
|
|
|
|
+ visitor_client.shutdown().await.unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ // Downstream should see header first, then payload
|
|
|
|
|
+ let mut buf = vec![0u8; expected_header.as_bytes().len() + payload.len()];
|
|
|
|
|
+ downstream.read_exact(&mut buf).await.unwrap();
|
|
|
|
|
+
|
|
|
|
|
+ let header_len = expected_header.as_bytes().len();
|
|
|
|
|
+ assert_eq!(&buf[..header_len], expected_header.as_bytes());
|
|
|
|
|
+ assert_eq!(&buf[header_len..], payload);
|
|
|
|
|
+
|
|
|
|
|
+ // Close downstream to let copy_bidirectional finish cleanly
|
|
|
|
|
+ drop(downstream);
|
|
|
|
|
+ server_task.await.unwrap();
|
|
|
|
|
+ }
|
|
|
|
|
+}
|