protocol.rs 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  /// Width, in bytes, of every [`Digest`] used by the protocol (a 256-bit hash).
  pub const HASH_WIDTH_IN_BYTES: usize = 256 / 8;
  2. use anyhow::{Context, Result};
  3. use bytes::{Bytes, BytesMut};
  4. use lazy_static::lazy_static;
  5. use serde::{Deserialize, Serialize};
  6. use std::net::SocketAddr;
  7. use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
  8. use tracing::trace;
  9. type ProtocolVersion = u8;
  10. const PROTO_V0: u8 = 0u8;
  11. pub const CURRENT_PROTO_VRESION: ProtocolVersion = PROTO_V0;
  12. pub type Digest = [u8; HASH_WIDTH_IN_BYTES];
  13. #[derive(Deserialize, Serialize, Debug)]
  14. pub enum Hello {
  15. ControlChannelHello(ProtocolVersion, Digest), // sha256sum(service name) or a nonce
  16. DataChannelHello(ProtocolVersion, Digest), // token provided by CreateDataChannel
  17. }
  18. #[derive(Deserialize, Serialize, Debug)]
  19. pub struct Auth(pub Digest);
  20. #[derive(Deserialize, Serialize, Debug)]
  21. pub enum Ack {
  22. Ok,
  23. ServiceNotExist,
  24. AuthFailed,
  25. }
  26. impl std::fmt::Display for Ack {
  27. fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  28. write!(
  29. f,
  30. "{}",
  31. match self {
  32. Ack::Ok => "Ok",
  33. Ack::ServiceNotExist => "Service not exist",
  34. Ack::AuthFailed => "Incorrect token",
  35. }
  36. )
  37. }
  38. }
  39. #[derive(Deserialize, Serialize, Debug)]
  40. pub enum ControlChannelCmd {
  41. CreateDataChannel,
  42. }
  43. #[derive(Deserialize, Serialize, Debug)]
  44. pub enum DataChannelCmd {
  45. StartForwardTcp,
  46. StartForwardUdp,
  47. }
  48. type UdpPacketLen = u16; // `u16` should be enough for any practical UDP traffic on the Internet
  49. #[derive(Deserialize, Serialize, Debug)]
  50. struct UdpHeader {
  51. from: SocketAddr,
  52. len: UdpPacketLen,
  53. }
  54. #[derive(Debug)]
  55. pub struct UdpTraffic {
  56. pub from: SocketAddr,
  57. pub data: Bytes,
  58. }
  59. impl UdpTraffic {
  60. pub async fn write<T: AsyncWrite + Unpin>(&self, writer: &mut T) -> Result<()> {
  61. let hdr = UdpHeader {
  62. from: self.from,
  63. len: self.data.len() as UdpPacketLen,
  64. };
  65. let v = bincode::serialize(&hdr).unwrap();
  66. trace!("Write {:?} of length {}", hdr, v.len());
  67. writer.write_u16(v.len() as u16).await?;
  68. writer.write_all(&v).await?;
  69. writer.write_all(&self.data).await?;
  70. Ok(())
  71. }
  72. #[allow(dead_code)]
  73. pub async fn write_slice<T: AsyncWrite + Unpin>(
  74. writer: &mut T,
  75. from: SocketAddr,
  76. data: &[u8],
  77. ) -> Result<()> {
  78. let hdr = UdpHeader {
  79. from,
  80. len: data.len() as UdpPacketLen,
  81. };
  82. let v = bincode::serialize(&hdr).unwrap();
  83. trace!("Write {:?} of length {}", hdr, v.len());
  84. writer.write_u16(v.len() as u16).await?;
  85. writer.write_all(&v).await?;
  86. writer.write_all(data).await?;
  87. Ok(())
  88. }
  89. pub async fn read<T: AsyncRead + Unpin>(reader: &mut T, hdr_len: u16) -> Result<UdpTraffic> {
  90. let mut buf = Vec::new();
  91. buf.resize(hdr_len as usize, 0);
  92. reader
  93. .read_exact(&mut buf)
  94. .await
  95. .with_context(|| "Failed to read udp header")?;
  96. let hdr: UdpHeader =
  97. bincode::deserialize(&buf).with_context(|| "Failed to deserialize UdpHeader")?;
  98. trace!("hdr {:?}", hdr);
  99. let mut data = BytesMut::new();
  100. data.resize(hdr.len as usize, 0);
  101. reader.read_exact(&mut data).await?;
  102. Ok(UdpTraffic {
  103. from: hdr.from,
  104. data: data.freeze(),
  105. })
  106. }
  107. }
  108. pub fn digest(data: &[u8]) -> Digest {
  109. use sha2::{Digest, Sha256};
  110. let d = Sha256::new().chain_update(data).finalize();
  111. d.into()
  112. }
  113. struct PacketLength {
  114. hello: usize,
  115. ack: usize,
  116. auth: usize,
  117. c_cmd: usize,
  118. d_cmd: usize,
  119. }
  120. impl PacketLength {
  121. pub fn new() -> PacketLength {
  122. let username = "default";
  123. let d = digest(username.as_bytes());
  124. let hello = bincode::serialized_size(&Hello::ControlChannelHello(CURRENT_PROTO_VRESION, d))
  125. .unwrap() as usize;
  126. let c_cmd =
  127. bincode::serialized_size(&ControlChannelCmd::CreateDataChannel).unwrap() as usize;
  128. let d_cmd = bincode::serialized_size(&DataChannelCmd::StartForwardTcp).unwrap() as usize;
  129. let ack = Ack::Ok;
  130. let ack = bincode::serialized_size(&ack).unwrap() as usize;
  131. let auth = bincode::serialized_size(&Auth(d)).unwrap() as usize;
  132. PacketLength {
  133. hello,
  134. ack,
  135. auth,
  136. c_cmd,
  137. d_cmd,
  138. }
  139. }
  140. }
  141. lazy_static! {
  142. static ref PACKET_LEN: PacketLength = PacketLength::new();
  143. }
  144. pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Result<Hello> {
  145. let mut buf = vec![0u8; PACKET_LEN.hello];
  146. conn.read_exact(&mut buf)
  147. .await
  148. .with_context(|| "Failed to read hello")?;
  149. let hello = bincode::deserialize(&buf).with_context(|| "Failed to deserialize hello")?;
  150. Ok(hello)
  151. }
  152. pub async fn read_auth<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Result<Auth> {
  153. let mut buf = vec![0u8; PACKET_LEN.auth];
  154. conn.read_exact(&mut buf)
  155. .await
  156. .with_context(|| "Failed to read auth")?;
  157. bincode::deserialize(&buf).with_context(|| "Failed to deserialize auth")
  158. }
  159. pub async fn read_ack<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Result<Ack> {
  160. let mut bytes = vec![0u8; PACKET_LEN.ack];
  161. conn.read_exact(&mut bytes)
  162. .await
  163. .with_context(|| "Failed to read ack")?;
  164. bincode::deserialize(&bytes).with_context(|| "Failed to deserialize ack")
  165. }
  166. pub async fn read_control_cmd<T: AsyncRead + AsyncWrite + Unpin>(
  167. conn: &mut T,
  168. ) -> Result<ControlChannelCmd> {
  169. let mut bytes = vec![0u8; PACKET_LEN.c_cmd];
  170. conn.read_exact(&mut bytes)
  171. .await
  172. .with_context(|| "Failed to read control cmd")?;
  173. bincode::deserialize(&bytes).with_context(|| "Failed to deserialize control cmd")
  174. }
  175. pub async fn read_data_cmd<T: AsyncRead + AsyncWrite + Unpin>(
  176. conn: &mut T,
  177. ) -> Result<DataChannelCmd> {
  178. let mut bytes = vec![0u8; PACKET_LEN.d_cmd];
  179. conn.read_exact(&mut bytes)
  180. .await
  181. .with_context(|| "Failed to read data cmd")?;
  182. bincode::deserialize(&bytes).with_context(|| "Failed to deserialize data cmd")
  183. }