Răsfoiți Sursa

feature: Import proxy-protocol v2 changes from pkarc/proxy-protocol (#455)

* feature: Import proxy-protocol changes from pkarc/proxy-protocol

* Accept cargo clippy changes

* Fix integration tests for proxy protocol v2

* Update docs
Stephen Tierney 2 săptămâni în urmă
părinte
comite
0f41051dcf

+ 1 - 0
README.md

@@ -171,6 +171,7 @@ type = "tcp" # Optional. Same as the client `[client.services.X.type]
 token = "whatever" # Necessary if `server.default_token` not set
 bind_addr = "0.0.0.0:8081" # Necessary. The address of the service is exposed at. Generally only the port needs to be change.
 nodelay = true # Optional. Same as the client
+proxy_protocol = "v2" # Optional. Prepend HAProxy PROXY protocol header to each incoming TCP connection before forwarding to the client. Possible values: ["v1", "v2"]. Default: disabled (unset). Only applies to TCP services.
 
 [server.services.service2]
 bind_addr = "0.0.0.1:8082"

+ 7 - 5
examples/proxy_protocol/server.toml

@@ -1,7 +1,9 @@
-# rathole configuration for proxy protocol enabled client
-# 
-# The service configuration has an additional `enable_proxy_protocol` boolean field.
-# Not setting this field defaults its value to `false` at runtime.
+# rathole configuration with Proxy Protocol enabled.
+#
+# Each service can optionally set `proxy_protocol` to:
+#   - "v1" (human-readable line)
+#   - "v2" (binary header)
+# If omitted, Proxy Protocol is disabled for that service.
 
 [server]
 bind_addr = "0.0.0.0:2333"
@@ -9,4 +11,4 @@ default_token = "123"
 
 [server.services.foo1]
 bind_addr = "0.0.0.0:5202"
-enable_proxy_protocol = true
+proxy_protocol = "v2"

+ 1 - 1
src/config.rs

@@ -104,7 +104,7 @@ pub struct ServerServiceConfig {
     pub bind_addr: String,
     pub token: Option<MaskedString>,
     pub nodelay: Option<bool>,
-    pub enable_proxy_protocol: Option<bool>,
+    pub proxy_protocol: Option<String>,
 }
 
 impl ServerServiceConfig {

+ 177 - 42
src/helper.rs

@@ -194,28 +194,74 @@ where
     Ok(())
 }
 
-pub fn generate_proxy_protocol_v1_header(s: &TcpStream) -> Result<String> {
+pub fn generate_proxy_protocol_header(s: &TcpStream, proxy_protocol: &str) -> Result<Vec<u8>, anyhow::Error> {
     let local_addr = s.local_addr()?;
     let remote_addr = s.peer_addr()?;
-    let proto = if local_addr.is_ipv4() { "TCP4" } else { "TCP6" };
-    let header = format!(
-        "PROXY {} {} {} {} {}\r\n", 
-        proto, 
-        remote_addr.ip(), 
-        local_addr.ip(), 
-        remote_addr.port(), 
-        local_addr.port()
-    );
-    Ok(header)
+
+    match proxy_protocol {
+        "v1" => {
+            let proto = if local_addr.is_ipv4() { "TCP4" } else { "TCP6" };
+            let header = format!(
+                "PROXY {} {} {} {} {}\r\n", 
+                proto, 
+                remote_addr.ip(), 
+                local_addr.ip(), 
+                remote_addr.port(), 
+                local_addr.port()
+            );
+
+            Ok(header.into_bytes())
+        }
+        "v2" => {
+
+            let v2sig: &[u8] = &[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A];
+            let ver_cmd = &[0x21]; // 0x21 version 2 and PROXY command
+            let proto = if local_addr.is_ipv4() { &[0x11] } else { &[0x21] }; // 0x11 for TCP IPv4 and 0x21 for TCP IPv6, TODO: support UNIX
+            let addrs_length: &[u8] = if local_addr.is_ipv4() { &[0, 12] } else { &[0, 36] }; // 12 for IPv4 and 36 for IPv6, TOOD: support UNIX
+            let src_addr = match remote_addr {
+                SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
+                SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
+            };
+            let dst_addr = match local_addr {
+                SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
+                SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
+            };
+    
+            let header:Vec<u8> = [
+                v2sig, 
+                ver_cmd, 
+                proto, 
+                addrs_length,
+                &src_addr,
+                &dst_addr,
+                &remote_addr.port().to_be_bytes(),
+                &local_addr.port().to_be_bytes()
+                ].concat();
+    
+            trace!("Proxy protocol v2 header: {:02x?}", header);
+    
+            Ok(header)
+
+        },
+        _ => {
+            Err(anyhow!("Unknown proxy protocol {}", proxy_protocol))
+        }
+    }
+
 }
 
 #[cfg(test)]
 mod proxy_protocol_tests {
-    use super::generate_proxy_protocol_v1_header;
+    use super::generate_proxy_protocol_header;
+    use std::net::{IpAddr, SocketAddr};
     use tokio::io::{AsyncReadExt, AsyncWriteExt};
     use tokio::net::{TcpListener, TcpStream};
 
-    fn expected_v1_header(local: std::net::SocketAddr, remote: std::net::SocketAddr) -> String {
+    const V2_SIG: [u8; 12] = [
+        0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
+    ];
+
+    fn expected_v1_header(local: SocketAddr, remote: SocketAddr) -> Vec<u8> {
         let proto = if local.is_ipv4() { "TCP4" } else { "TCP6" };
         format!(
             "PROXY {proto} {} {} {} {}\r\n",
@@ -224,6 +270,34 @@ mod proxy_protocol_tests {
             remote.port(),
             local.port()
         )
+        .into_bytes()
+    }
+
+    fn expected_v2_header(local: SocketAddr, remote: SocketAddr) -> Vec<u8> {
+        let mut out = Vec::new();
+        out.extend_from_slice(&V2_SIG);
+        out.push(0x21); // v2 + PROXY command
+
+        match (remote.ip(), local.ip()) {
+            (IpAddr::V4(src), IpAddr::V4(dst)) => {
+                out.push(0x11); // AF_INET (0x1) + STREAM (0x1) => 0x11
+                out.extend_from_slice(&[0x00, 0x0c]); // len = 12
+                out.extend_from_slice(&src.octets());
+                out.extend_from_slice(&dst.octets());
+            }
+            (IpAddr::V6(src), IpAddr::V6(dst)) => {
+                out.push(0x21); // AF_INET6 (0x2) + STREAM (0x1) => 0x21
+                out.extend_from_slice(&[0x00, 0x24]); // len = 36
+                out.extend_from_slice(&src.octets());
+                out.extend_from_slice(&dst.octets());
+            }
+            _ => panic!("mismatched address families in test"),
+        }
+
+        // src port then dst port
+        out.extend_from_slice(&remote.port().to_be_bytes());
+        out.extend_from_slice(&local.port().to_be_bytes());
+        out
     }
 
     #[tokio::test]
@@ -231,27 +305,50 @@ mod proxy_protocol_tests {
         let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
         let addr = listener.local_addr().unwrap();
 
-        // Create a connection so we get a real TcpStream with real peer/local addrs
         let _client = TcpStream::connect(addr).await.unwrap();
         let (server, _) = listener.accept().await.unwrap();
 
         let local = server.local_addr().unwrap();
         let remote = server.peer_addr().unwrap();
-
         assert!(local.is_ipv4());
         assert!(remote.is_ipv4());
 
         let expected = expected_v1_header(local, remote);
-        let got = generate_proxy_protocol_v1_header(&server).unwrap();
+        let got = generate_proxy_protocol_header(&server, "v1").unwrap();
 
         assert_eq!(got, expected);
-        assert!(got.ends_with("\r\n"));
-        assert!(got.starts_with("PROXY TCP4 "));
+        assert!(got.ends_with(b"\r\n"));
+        assert!(got.starts_with(b"PROXY TCP4 "));
+    }
+
+    #[tokio::test]
+    async fn v2_header_ipv4_format_is_correct() {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        let _client = TcpStream::connect(addr).await.unwrap();
+        let (server, _) = listener.accept().await.unwrap();
+
+        let local = server.local_addr().unwrap();
+        let remote = server.peer_addr().unwrap();
+        assert!(local.is_ipv4());
+        assert!(remote.is_ipv4());
+
+        let expected = expected_v2_header(local, remote);
+        let got = generate_proxy_protocol_header(&server, "v2").unwrap();
+
+        assert_eq!(got, expected);
+
+        // Spot-check fixed fields and sizes
+        assert_eq!(&got[..12], &V2_SIG);
+        assert_eq!(got[12], 0x21);
+        assert_eq!(got[13], 0x11);
+        assert_eq!(&got[14..16], &[0x00, 0x0c]);
+        assert_eq!(got.len(), 28);
     }
 
     #[tokio::test]
     async fn v1_header_ipv6_format_is_correct_or_skipped_if_unavailable() {
-        // Some CI environments don’t have IPv6 loopback enabled; skip gracefully.
         let listener = match TcpListener::bind("[::1]:0").await {
             Ok(l) => l,
             Err(_) => return,
@@ -266,43 +363,74 @@ mod proxy_protocol_tests {
 
         let local = server.local_addr().unwrap();
         let remote = server.peer_addr().unwrap();
-
         assert!(local.is_ipv6());
         assert!(remote.is_ipv6());
 
         let expected = expected_v1_header(local, remote);
-        let got = generate_proxy_protocol_v1_header(&server).unwrap();
+        let got = generate_proxy_protocol_header(&server, "v1").unwrap();
 
         assert_eq!(got, expected);
-        assert!(got.ends_with("\r\n"));
-        assert!(got.starts_with("PROXY TCP6 "));
+        assert!(got.ends_with(b"\r\n"));
+        assert!(got.starts_with(b"PROXY TCP6 "));
     }
 
     #[tokio::test]
-    async fn header_bytes_are_sent_before_payload_when_written_then_forwarded() {
-        // This simulates the exact ordering your server code implements:
-        // write PROXY header -> flush -> start forwarding bytes. :contentReference[oaicite:1]{index=1}
+    async fn v2_header_ipv6_format_is_correct_or_skipped_if_unavailable() {
+        let listener = match TcpListener::bind("[::1]:0").await {
+            Ok(l) => l,
+            Err(_) => return,
+        };
+        let addr = listener.local_addr().unwrap();
+
+        let _client = match TcpStream::connect(addr).await {
+            Ok(c) => c,
+            Err(_) => return,
+        };
+        let (server, _) = listener.accept().await.unwrap();
+
+        let local = server.local_addr().unwrap();
+        let remote = server.peer_addr().unwrap();
+        assert!(local.is_ipv6());
+        assert!(remote.is_ipv6());
+
+        let expected = expected_v2_header(local, remote);
+        let got = generate_proxy_protocol_header(&server, "v2").unwrap();
+
+        assert_eq!(got, expected);
+        assert_eq!(&got[..12], &V2_SIG);
+        assert_eq!(got[12], 0x21);
+        assert_eq!(got[13], 0x21);
+        assert_eq!(&got[14..16], &[0x00, 0x24]);
+        assert_eq!(got.len(), 52);
+    }
+
+    #[tokio::test]
+    async fn unknown_proxy_protocol_is_rejected() {
+        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
+        let addr = listener.local_addr().unwrap();
+
+        let _client = TcpStream::connect(addr).await.unwrap();
+        let (server, _) = listener.accept().await.unwrap();
+
+        let err = generate_proxy_protocol_header(&server, "nope").unwrap_err();
+        assert!(err.to_string().contains("Unknown proxy protocol"));
+    }
 
-        // Visitor side connection (incoming connection to the server)
+    async fn header_is_sent_before_payload(version: &'static str) {
         let visitor_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
         let visitor_addr = visitor_listener.local_addr().unwrap();
 
         let mut visitor_client = TcpStream::connect(visitor_addr).await.unwrap();
         let (mut visitor_server, _) = visitor_listener.accept().await.unwrap();
 
-        let local = visitor_server.local_addr().unwrap();
-        let remote = visitor_server.peer_addr().unwrap();
-        let expected_header = expected_v1_header(local, remote);
-
+        let expected_header = generate_proxy_protocol_header(&visitor_server, version).unwrap();
         let payload = b"hello proxy protocol";
 
-        // Simulate the “data channel” stream with a duplex pipe.
-        // (One end is what the server writes into; the other end is what the downstream reads.)
         let (mut ch, mut downstream) = tokio::io::duplex(4096);
 
         let server_task = tokio::spawn(async move {
-            let header = generate_proxy_protocol_v1_header(&visitor_server).unwrap();
-            ch.write_all(header.as_bytes()).await.unwrap();
+            let header = generate_proxy_protocol_header(&visitor_server, version).unwrap();
+            ch.write_all(&header).await.unwrap();
             ch.flush().await.unwrap();
 
             tokio::io::copy_bidirectional(&mut visitor_server, &mut ch)
@@ -310,20 +438,27 @@ mod proxy_protocol_tests {
                 .unwrap();
         });
 
-        // Visitor sends payload
         visitor_client.write_all(payload).await.unwrap();
         visitor_client.shutdown().await.unwrap();
 
-        // Downstream should see header first, then payload
-        let mut buf = vec![0u8; expected_header.as_bytes().len() + payload.len()];
+        let mut buf = vec![0u8; expected_header.len() + payload.len()];
         downstream.read_exact(&mut buf).await.unwrap();
 
-        let header_len = expected_header.as_bytes().len();
-        assert_eq!(&buf[..header_len], expected_header.as_bytes());
+        let header_len = expected_header.len();
+        assert_eq!(&buf[..header_len], &expected_header);
         assert_eq!(&buf[header_len..], payload);
 
-        // Close downstream to let copy_bidirectional finish cleanly
         drop(downstream);
         server_task.await.unwrap();
     }
-}
+
+    #[tokio::test]
+    async fn v1_header_bytes_are_sent_before_payload_when_forwarding() {
+        header_is_sent_before_payload("v1").await;
+    }
+
+    #[tokio::test]
+    async fn v2_header_bytes_are_sent_before_payload_when_forwarding() {
+        header_is_sent_before_payload("v2").await;
+    }
+}

+ 22 - 10
src/server.rs

@@ -1,7 +1,7 @@
 use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
 use crate::config_watcher::{ConfigChange, ServerServiceChange};
 use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
-use crate::helper::{generate_proxy_protocol_v1_header, retry_notify_with_deadline, write_and_flush};
+use crate::helper::{generate_proxy_protocol_header, retry_notify_with_deadline, write_and_flush};
 use crate::multi_map::MultiMap;
 use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
 use crate::protocol::{
@@ -427,16 +427,20 @@ where
 
         let shutdown_rx_clone = shutdown_tx.subscribe();
         let bind_addr = service.bind_addr.clone();
-        let enable_proxy_protocol = service.enable_proxy_protocol.unwrap_or_default();
-        if enable_proxy_protocol {
-            debug!("Proxy protocol is enabled");
+        let proxy_protocol = service.proxy_protocol.clone().unwrap_or_default();
+        if proxy_protocol == "v1" || proxy_protocol == "v2" {
+            info!("Proxy protocol {:?} is enabled", proxy_protocol);
+        } else if proxy_protocol.is_empty() {
+            info!("Proxy protocol is disabled");
+        } else {
+            error!("Unknown proxy protocol {}", proxy_protocol);
         }
         match service.service_type {
             ServiceType::Tcp => tokio::spawn(
                 async move {
                     if let Err(e) = run_tcp_connection_pool::<T>(
                         bind_addr,
-                        enable_proxy_protocol,
+                        proxy_protocol.clone(),
                         data_ch_rx,
                         data_ch_req_tx,
                         shutdown_rx_clone,
@@ -630,7 +634,7 @@ fn tcp_listen_and_send(
 #[instrument(skip_all)]
 async fn run_tcp_connection_pool<T: Transport>(
     bind_addr: String,
-    enable_proxy_protocol: bool,
+    proxy_protocol: String,
     mut data_ch_rx: mpsc::Receiver<T::Stream>,
     data_ch_req_tx: mpsc::UnboundedSender<bool>,
     shutdown_rx: broadcast::Receiver<bool>,
@@ -642,11 +646,19 @@ async fn run_tcp_connection_pool<T: Transport>(
         loop {
             if let Some(mut ch) = data_ch_rx.recv().await {
                 if write_and_flush(&mut ch, &cmd).await.is_ok() {
+                    let proxy_proto = proxy_protocol.clone();
                     tokio::spawn(async move {
-                        if enable_proxy_protocol {
-                            let proxy_proto_header = generate_proxy_protocol_v1_header(&visitor).unwrap();
-                            let _ = ch.write_all(&proxy_proto_header.into_bytes()).await;
-                            let _ = ch.flush().await;
+                        if !proxy_proto.is_empty() {
+                            let proxy_proto_header = generate_proxy_protocol_header(&visitor, &proxy_proto);
+                            match proxy_proto_header {
+                                Ok(header) => {
+                                    let _ = ch.write_all(&header).await;
+                                    let _ = ch.flush().await;
+                                },
+                                Err(e) => {
+                                    error!("Failed to generate proxy protocol header: {}", e);
+                                }
+                            }
                         }
                         let _ = copy_bidirectional(&mut ch, &mut visitor).await;
                     });

+ 1 - 1
tests/for_tcp/tcp_transport_proxy_protocol.toml → tests/for_tcp/tcp_transport_proxy_protocol_v1.toml

@@ -19,7 +19,7 @@ type = "tcp"
 
 [server.services.echo]
 bind_addr = "0.0.0.0:2334"
-enable_proxy_protocol = true
+proxy_protocol = "v1"
 
 [server.services.pingpong]
 bind_addr = "0.0.0.0:2335"

+ 25 - 0
tests/for_tcp/tcp_transport_proxy_protocol_v2.toml

@@ -0,0 +1,25 @@
+[client]
+remote_addr = "127.0.0.1:2333"
+default_token = "default_token_if_not_specify"
+
+[client.transport]
+type = "tcp"
+
+[client.services.echo]
+local_addr = "127.0.0.1:8080"
+[client.services.pingpong]
+local_addr = "127.0.0.1:8081"
+
+[server]
+bind_addr = "0.0.0.0:2333"
+default_token = "default_token_if_not_specify"
+
+[server.transport]
+type = "tcp"
+
+[server.services.echo]
+bind_addr = "0.0.0.0:2334"
+proxy_protocol = "v2"
+
+[server.services.pingpong]
+bind_addr = "0.0.0.0:2335"

+ 116 - 17
tests/integration_test.rs

@@ -1,6 +1,8 @@
-use anyhow::{Ok, Result};
+use anyhow::{anyhow, Ok, Result};
 use common::{run_rathole_client, PING, PONG};
 use rand::Rng;
+use rand::RngCore;
+use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
 use std::time::Duration;
 use tokio::{
     io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
@@ -21,6 +23,11 @@ const ECHO_SERVER_ADDR_EXPOSED: &str = "127.0.0.1:2334";
 const PINGPONG_SERVER_ADDR_EXPOSED: &str = "127.0.0.1:2335";
 const HITTER_NUM: usize = 4;
 
+const PP2_SIG: [u8; 12] = [
+    0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
+];
+
+
 #[derive(Clone, Copy, Debug)]
 enum Type {
     Tcp,
@@ -56,7 +63,8 @@ async fn tcp() -> Result<()> {
 
     test("tests/for_tcp/tcp_transport.toml", Type::Tcp).await?;
 
-    test_proxy_protocol("tests/for_tcp/tcp_transport_proxy_protocol.toml").await?;
+    test_proxy_protocol("tests/for_tcp/tcp_transport_proxy_protocol_v1.toml").await?;
+    test_proxy_protocol("tests/for_tcp/tcp_transport_proxy_protocol_v2.toml").await?;
 
     #[cfg(any(
          // FIXME: Self-signed certificate on macOS nativetls requires manual interference.
@@ -349,6 +357,86 @@ async fn test_proxy_protocol(config_path: &'static str) -> Result<()> {
     Ok(())
 }
 
+async fn read_proxy_protocol_header(rd: &mut BufReader<tokio::net::tcp::OwnedReadHalf>) -> Result<Vec<u8>> {
+    // Read 12 bytes to distinguish v2 signature vs v1 ("PROXY ...")
+    let mut first12 = [0u8; 12];
+    time::timeout(Duration::from_secs(5), rd.read_exact(&mut first12)).await??;
+
+    if first12 == PP2_SIG {
+        // v2: read fixed header (ver/cmd, fam/proto, len[2]) then read len bytes
+        let mut fixed = [0u8; 4];
+        time::timeout(Duration::from_secs(5), rd.read_exact(&mut fixed)).await??;
+
+        let len = u16::from_be_bytes([fixed[2], fixed[3]]) as usize;
+        let mut addr_and_tlvs = vec![0u8; len];
+        time::timeout(Duration::from_secs(5), rd.read_exact(&mut addr_and_tlvs)).await??;
+
+        let mut out = Vec::with_capacity(16 + len);
+        out.extend_from_slice(&first12);
+        out.extend_from_slice(&fixed);
+        out.extend_from_slice(&addr_and_tlvs);
+        Ok(out)
+    } else {
+        // v1: we've already consumed 12 bytes; read until newline to complete the line
+        let mut out = first12.to_vec();
+        let n = time::timeout(Duration::from_secs(5), rd.read_until(b'\n', &mut out)).await??;
+        if n == 0 {
+            return Err(anyhow!("EOF while reading proxy protocol v1 line"));
+        }
+        Ok(out)
+    }
+}
+
+fn assert_proxy_v2_matches(header: &[u8], local: SocketAddr, peer: SocketAddr) {
+    assert!(header.len() >= 16);
+    assert_eq!(&header[..12], &PP2_SIG);
+
+    // version/command
+    assert_eq!(header[12], 0x21, "expected v2 PROXY command (0x21)");
+
+    let fam_proto = header[13];
+    let len = u16::from_be_bytes([header[14], header[15]]) as usize;
+    assert_eq!(header.len(), 16 + len, "v2 length mismatch");
+
+    match fam_proto {
+        0x11 => {
+            // INET + STREAM, minimum 12 bytes address block
+            assert!(len >= 12);
+
+            let src = IpAddr::V4(Ipv4Addr::new(header[16], header[17], header[18], header[19]));
+            let dst = IpAddr::V4(Ipv4Addr::new(header[20], header[21], header[22], header[23]));
+            let src_port = u16::from_be_bytes([header[24], header[25]]);
+            let dst_port = u16::from_be_bytes([header[26], header[27]]);
+
+            assert_eq!(src, local.ip());
+            assert_eq!(dst, peer.ip());
+            assert_eq!(src_port, local.port());
+            assert_eq!(dst_port, peer.port());
+        }
+        0x21 => {
+            // INET6 + STREAM, minimum 36 bytes address block
+            assert!(len >= 36);
+
+            let mut src_oct = [0u8; 16];
+            let mut dst_oct = [0u8; 16];
+            src_oct.copy_from_slice(&header[16..32]);
+            dst_oct.copy_from_slice(&header[32..48]);
+
+            let src = IpAddr::V6(Ipv6Addr::from(src_oct));
+            let dst = IpAddr::V6(Ipv6Addr::from(dst_oct));
+            let src_port = u16::from_be_bytes([header[48], header[49]]);
+            let dst_port = u16::from_be_bytes([header[50], header[51]]);
+
+            assert_eq!(src, local.ip());
+            assert_eq!(dst, peer.ip());
+            assert_eq!(src_port, local.port());
+            assert_eq!(dst_port, peer.port());
+        }
+        other => panic!("unexpected v2 fam/proto byte: {other:#x}"),
+    }
+}
+
+
 async fn tcp_echo_hitter_expect_proxy_protocol(addr: &'static str) -> Result<()> {
     let conn = TcpStream::connect(addr).await?;
     let local = conn.local_addr()?;
@@ -357,30 +445,41 @@ async fn tcp_echo_hitter_expect_proxy_protocol(addr: &'static str) -> Result<()>
     let (rd, mut wr) = conn.into_split();
     let mut rd = BufReader::new(rd);
 
-    // Read the echoed PROXY header line first.
-    let mut header = String::new();
-    let n = time::timeout(Duration::from_secs(5), rd.read_line(&mut header)).await??;
-    assert!(n > 0, "expected a proxy protocol header line");
-
-    let expected = format!(
-        "PROXY TCP4 {} {} {} {}\r\n",
-        local.ip(),
-        peer.ip(),
-        local.port(),
-        peer.port()
-    );
-    assert_eq!(header, expected);
+    // Read & validate proxy protocol header (v1 or v2)
+    let header = read_proxy_protocol_header(&mut rd).await?;
+
+    if header.starts_with(b"PROXY ") {
+        // v1 assertion (stringy)
+        let proto = if local.is_ipv4() { "TCP4" } else { "TCP6" };
+        let expected = format!(
+            "PROXY {proto} {} {} {} {}\r\n",
+            local.ip(),
+            peer.ip(),
+            local.port(),
+            peer.port()
+        )
+        .into_bytes();
+        assert_eq!(header, expected);
+    } else {
+        // v2 assertion (binary)
+        assert_proxy_v2_matches(&header, local, peer);
+    }
 
     // Now the stream should behave like a normal echo connection.
     let mut wr_buf = [0u8; 1024];
     let mut rd_buf = [0u8; 1024];
 
     for _ in 0..100 {
-        rand::thread_rng().fill(&mut wr_buf);
+        rand::thread_rng().fill_bytes(&mut wr_buf);
         wr.write_all(&wr_buf).await?;
         rd.read_exact(&mut rd_buf).await?;
         assert_eq!(wr_buf, rd_buf);
     }
 
     Ok(())
-}
+}
+
+
+
+
+