|
|
@@ -194,28 +194,74 @@ where
|
|
|
Ok(())
|
|
|
}
|
|
|
|
|
|
-pub fn generate_proxy_protocol_v1_header(s: &TcpStream) -> Result<String> {
|
|
|
+pub fn generate_proxy_protocol_header(s: &TcpStream, proxy_protocol: &str) -> Result<Vec<u8>, anyhow::Error> {
|
|
|
let local_addr = s.local_addr()?;
|
|
|
let remote_addr = s.peer_addr()?;
|
|
|
- let proto = if local_addr.is_ipv4() { "TCP4" } else { "TCP6" };
|
|
|
- let header = format!(
|
|
|
- "PROXY {} {} {} {} {}\r\n",
|
|
|
- proto,
|
|
|
- remote_addr.ip(),
|
|
|
- local_addr.ip(),
|
|
|
- remote_addr.port(),
|
|
|
- local_addr.port()
|
|
|
- );
|
|
|
- Ok(header)
|
|
|
+
|
|
|
+ match proxy_protocol {
|
|
|
+ "v1" => {
|
|
|
+ let proto = if local_addr.is_ipv4() { "TCP4" } else { "TCP6" };
|
|
|
+ let header = format!(
|
|
|
+ "PROXY {} {} {} {} {}\r\n",
|
|
|
+ proto,
|
|
|
+ remote_addr.ip(),
|
|
|
+ local_addr.ip(),
|
|
|
+ remote_addr.port(),
|
|
|
+ local_addr.port()
|
|
|
+ );
|
|
|
+
|
|
|
+ Ok(header.into_bytes())
|
|
|
+ }
|
|
|
+ "v2" => {
|
|
|
+
|
|
|
+ let v2sig: &[u8] = &[0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A];
|
|
|
+ let ver_cmd = &[0x21]; // 0x21 version 2 and PROXY command
|
|
|
+ let proto = if local_addr.is_ipv4() { &[0x11] } else { &[0x21] }; // 0x11 for TCP IPv4 and 0x21 for TCP IPv6, TODO: support UNIX
|
|
|
+ let addrs_length: &[u8] = if local_addr.is_ipv4() { &[0, 12] } else { &[0, 36] }; // 12 for IPv4 and 36 for IPv6, TOOD: support UNIX
|
|
|
+ let src_addr = match remote_addr {
|
|
|
+ SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
|
|
|
+ SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
|
|
|
+ };
|
|
|
+ let dst_addr = match local_addr {
|
|
|
+ SocketAddr::V4(v4) => v4.ip().octets().to_vec(),
|
|
|
+ SocketAddr::V6(v6) => v6.ip().octets().to_vec(),
|
|
|
+ };
|
|
|
+
|
|
|
+ let header:Vec<u8> = [
|
|
|
+ v2sig,
|
|
|
+ ver_cmd,
|
|
|
+ proto,
|
|
|
+ addrs_length,
|
|
|
+ &src_addr,
|
|
|
+ &dst_addr,
|
|
|
+ &remote_addr.port().to_be_bytes(),
|
|
|
+ &local_addr.port().to_be_bytes()
|
|
|
+ ].concat();
|
|
|
+
|
|
|
+ trace!("Proxy protocol v2 header: {:02x?}", header);
|
|
|
+
|
|
|
+ Ok(header)
|
|
|
+
|
|
|
+ },
|
|
|
+ _ => {
|
|
|
+ Err(anyhow!("Unknown proxy protocol {}", proxy_protocol))
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
}
|
|
|
|
|
|
#[cfg(test)]
|
|
|
mod proxy_protocol_tests {
|
|
|
- use super::generate_proxy_protocol_v1_header;
|
|
|
+ use super::generate_proxy_protocol_header;
|
|
|
+ use std::net::{IpAddr, SocketAddr};
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
|
- fn expected_v1_header(local: std::net::SocketAddr, remote: std::net::SocketAddr) -> String {
|
|
|
+ const V2_SIG: [u8; 12] = [
|
|
|
+ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
|
|
+ ];
|
|
|
+
|
|
|
+ fn expected_v1_header(local: SocketAddr, remote: SocketAddr) -> Vec<u8> {
|
|
|
let proto = if local.is_ipv4() { "TCP4" } else { "TCP6" };
|
|
|
format!(
|
|
|
"PROXY {proto} {} {} {} {}\r\n",
|
|
|
@@ -224,6 +270,34 @@ mod proxy_protocol_tests {
|
|
|
remote.port(),
|
|
|
local.port()
|
|
|
)
|
|
|
+ .into_bytes()
|
|
|
+ }
|
|
|
+
|
|
|
+ fn expected_v2_header(local: SocketAddr, remote: SocketAddr) -> Vec<u8> {
|
|
|
+ let mut out = Vec::new();
|
|
|
+ out.extend_from_slice(&V2_SIG);
|
|
|
+ out.push(0x21); // v2 + PROXY command
|
|
|
+
|
|
|
+ match (remote.ip(), local.ip()) {
|
|
|
+ (IpAddr::V4(src), IpAddr::V4(dst)) => {
|
|
|
+ out.push(0x11); // AF_INET (0x1) + STREAM (0x1) => 0x11
|
|
|
+ out.extend_from_slice(&[0x00, 0x0c]); // len = 12
|
|
|
+ out.extend_from_slice(&src.octets());
|
|
|
+ out.extend_from_slice(&dst.octets());
|
|
|
+ }
|
|
|
+ (IpAddr::V6(src), IpAddr::V6(dst)) => {
|
|
|
+ out.push(0x21); // AF_INET6 (0x2) + STREAM (0x1) => 0x21
|
|
|
+ out.extend_from_slice(&[0x00, 0x24]); // len = 36
|
|
|
+ out.extend_from_slice(&src.octets());
|
|
|
+ out.extend_from_slice(&dst.octets());
|
|
|
+ }
|
|
|
+ _ => panic!("mismatched address families in test"),
|
|
|
+ }
|
|
|
+
|
|
|
+ // src port then dst port
|
|
|
+ out.extend_from_slice(&remote.port().to_be_bytes());
|
|
|
+ out.extend_from_slice(&local.port().to_be_bytes());
|
|
|
+ out
|
|
|
}
|
|
|
|
|
|
#[tokio::test]
|
|
|
@@ -231,27 +305,50 @@ mod proxy_protocol_tests {
|
|
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
|
let addr = listener.local_addr().unwrap();
|
|
|
|
|
|
- // Create a connection so we get a real TcpStream with real peer/local addrs
|
|
|
let _client = TcpStream::connect(addr).await.unwrap();
|
|
|
let (server, _) = listener.accept().await.unwrap();
|
|
|
|
|
|
let local = server.local_addr().unwrap();
|
|
|
let remote = server.peer_addr().unwrap();
|
|
|
-
|
|
|
assert!(local.is_ipv4());
|
|
|
assert!(remote.is_ipv4());
|
|
|
|
|
|
let expected = expected_v1_header(local, remote);
|
|
|
- let got = generate_proxy_protocol_v1_header(&server).unwrap();
|
|
|
+ let got = generate_proxy_protocol_header(&server, "v1").unwrap();
|
|
|
|
|
|
assert_eq!(got, expected);
|
|
|
- assert!(got.ends_with("\r\n"));
|
|
|
- assert!(got.starts_with("PROXY TCP4 "));
|
|
|
+ assert!(got.ends_with(b"\r\n"));
|
|
|
+ assert!(got.starts_with(b"PROXY TCP4 "));
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn v2_header_ipv4_format_is_correct() {
|
|
|
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
|
+ let addr = listener.local_addr().unwrap();
|
|
|
+
|
|
|
+ let _client = TcpStream::connect(addr).await.unwrap();
|
|
|
+ let (server, _) = listener.accept().await.unwrap();
|
|
|
+
|
|
|
+ let local = server.local_addr().unwrap();
|
|
|
+ let remote = server.peer_addr().unwrap();
|
|
|
+ assert!(local.is_ipv4());
|
|
|
+ assert!(remote.is_ipv4());
|
|
|
+
|
|
|
+ let expected = expected_v2_header(local, remote);
|
|
|
+ let got = generate_proxy_protocol_header(&server, "v2").unwrap();
|
|
|
+
|
|
|
+ assert_eq!(got, expected);
|
|
|
+
|
|
|
+ // Spot-check fixed fields and sizes
|
|
|
+ assert_eq!(&got[..12], &V2_SIG);
|
|
|
+ assert_eq!(got[12], 0x21);
|
|
|
+ assert_eq!(got[13], 0x11);
|
|
|
+ assert_eq!(&got[14..16], &[0x00, 0x0c]);
|
|
|
+ assert_eq!(got.len(), 28);
|
|
|
}
|
|
|
|
|
|
#[tokio::test]
|
|
|
async fn v1_header_ipv6_format_is_correct_or_skipped_if_unavailable() {
|
|
|
- // Some CI environments don’t have IPv6 loopback enabled; skip gracefully.
|
|
|
let listener = match TcpListener::bind("[::1]:0").await {
|
|
|
Ok(l) => l,
|
|
|
Err(_) => return,
|
|
|
@@ -266,43 +363,74 @@ mod proxy_protocol_tests {
|
|
|
|
|
|
let local = server.local_addr().unwrap();
|
|
|
let remote = server.peer_addr().unwrap();
|
|
|
-
|
|
|
assert!(local.is_ipv6());
|
|
|
assert!(remote.is_ipv6());
|
|
|
|
|
|
let expected = expected_v1_header(local, remote);
|
|
|
- let got = generate_proxy_protocol_v1_header(&server).unwrap();
|
|
|
+ let got = generate_proxy_protocol_header(&server, "v1").unwrap();
|
|
|
|
|
|
assert_eq!(got, expected);
|
|
|
- assert!(got.ends_with("\r\n"));
|
|
|
- assert!(got.starts_with("PROXY TCP6 "));
|
|
|
+ assert!(got.ends_with(b"\r\n"));
|
|
|
+ assert!(got.starts_with(b"PROXY TCP6 "));
|
|
|
}
|
|
|
|
|
|
#[tokio::test]
|
|
|
- async fn header_bytes_are_sent_before_payload_when_written_then_forwarded() {
|
|
|
- // This simulates the exact ordering your server code implements:
|
|
|
- // write PROXY header -> flush -> start forwarding bytes. :contentReference[oaicite:1]{index=1}
|
|
|
+ async fn v2_header_ipv6_format_is_correct_or_skipped_if_unavailable() {
|
|
|
+ let listener = match TcpListener::bind("[::1]:0").await {
|
|
|
+ Ok(l) => l,
|
|
|
+ Err(_) => return,
|
|
|
+ };
|
|
|
+ let addr = listener.local_addr().unwrap();
|
|
|
+
|
|
|
+ let _client = match TcpStream::connect(addr).await {
|
|
|
+ Ok(c) => c,
|
|
|
+ Err(_) => return,
|
|
|
+ };
|
|
|
+ let (server, _) = listener.accept().await.unwrap();
|
|
|
+
|
|
|
+ let local = server.local_addr().unwrap();
|
|
|
+ let remote = server.peer_addr().unwrap();
|
|
|
+ assert!(local.is_ipv6());
|
|
|
+ assert!(remote.is_ipv6());
|
|
|
+
|
|
|
+ let expected = expected_v2_header(local, remote);
|
|
|
+ let got = generate_proxy_protocol_header(&server, "v2").unwrap();
|
|
|
+
|
|
|
+ assert_eq!(got, expected);
|
|
|
+ assert_eq!(&got[..12], &V2_SIG);
|
|
|
+ assert_eq!(got[12], 0x21);
|
|
|
+ assert_eq!(got[13], 0x21);
|
|
|
+ assert_eq!(&got[14..16], &[0x00, 0x24]);
|
|
|
+ assert_eq!(got.len(), 52);
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn unknown_proxy_protocol_is_rejected() {
|
|
|
+ let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
|
+ let addr = listener.local_addr().unwrap();
|
|
|
+
|
|
|
+ let _client = TcpStream::connect(addr).await.unwrap();
|
|
|
+ let (server, _) = listener.accept().await.unwrap();
|
|
|
+
|
|
|
+ let err = generate_proxy_protocol_header(&server, "nope").unwrap_err();
|
|
|
+ assert!(err.to_string().contains("Unknown proxy protocol"));
|
|
|
+ }
|
|
|
|
|
|
- // Visitor side connection (incoming connection to the server)
|
|
|
+ async fn header_is_sent_before_payload(version: &'static str) {
|
|
|
let visitor_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
|
let visitor_addr = visitor_listener.local_addr().unwrap();
|
|
|
|
|
|
let mut visitor_client = TcpStream::connect(visitor_addr).await.unwrap();
|
|
|
let (mut visitor_server, _) = visitor_listener.accept().await.unwrap();
|
|
|
|
|
|
- let local = visitor_server.local_addr().unwrap();
|
|
|
- let remote = visitor_server.peer_addr().unwrap();
|
|
|
- let expected_header = expected_v1_header(local, remote);
|
|
|
-
|
|
|
+ let expected_header = generate_proxy_protocol_header(&visitor_server, version).unwrap();
|
|
|
let payload = b"hello proxy protocol";
|
|
|
|
|
|
- // Simulate the “data channel” stream with a duplex pipe.
|
|
|
- // (One end is what the server writes into; the other end is what the downstream reads.)
|
|
|
let (mut ch, mut downstream) = tokio::io::duplex(4096);
|
|
|
|
|
|
let server_task = tokio::spawn(async move {
|
|
|
- let header = generate_proxy_protocol_v1_header(&visitor_server).unwrap();
|
|
|
- ch.write_all(header.as_bytes()).await.unwrap();
|
|
|
+ let header = generate_proxy_protocol_header(&visitor_server, version).unwrap();
|
|
|
+ ch.write_all(&header).await.unwrap();
|
|
|
ch.flush().await.unwrap();
|
|
|
|
|
|
tokio::io::copy_bidirectional(&mut visitor_server, &mut ch)
|
|
|
@@ -310,20 +438,27 @@ mod proxy_protocol_tests {
|
|
|
.unwrap();
|
|
|
});
|
|
|
|
|
|
- // Visitor sends payload
|
|
|
visitor_client.write_all(payload).await.unwrap();
|
|
|
visitor_client.shutdown().await.unwrap();
|
|
|
|
|
|
- // Downstream should see header first, then payload
|
|
|
- let mut buf = vec![0u8; expected_header.as_bytes().len() + payload.len()];
|
|
|
+ let mut buf = vec![0u8; expected_header.len() + payload.len()];
|
|
|
downstream.read_exact(&mut buf).await.unwrap();
|
|
|
|
|
|
- let header_len = expected_header.as_bytes().len();
|
|
|
- assert_eq!(&buf[..header_len], expected_header.as_bytes());
|
|
|
+ let header_len = expected_header.len();
|
|
|
+ assert_eq!(&buf[..header_len], &expected_header);
|
|
|
assert_eq!(&buf[header_len..], payload);
|
|
|
|
|
|
- // Close downstream to let copy_bidirectional finish cleanly
|
|
|
drop(downstream);
|
|
|
server_task.await.unwrap();
|
|
|
}
|
|
|
-}
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn v1_header_bytes_are_sent_before_payload_when_forwarding() {
|
|
|
+ header_is_sent_before_payload("v1").await;
|
|
|
+ }
|
|
|
+
|
|
|
+ #[tokio::test]
|
|
|
+ async fn v2_header_bytes_are_sent_before_payload_when_forwarding() {
|
|
|
+ header_is_sent_before_payload("v2").await;
|
|
|
+ }
|
|
|
+}
|