// client.rs (14 KB)
  1. use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
  2. use crate::helper::udp_connect;
  3. use crate::protocol::Hello::{self, *};
  4. use crate::protocol::{
  5. self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
  6. DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
  7. };
  8. use crate::transport::{NoiseTransport, TcpTransport, Transport};
  9. use anyhow::{anyhow, bail, Context, Result};
  10. use backoff::ExponentialBackoff;
  11. use bytes::{Bytes, BytesMut};
  12. use std::collections::HashMap;
  13. use std::net::SocketAddr;
  14. use std::sync::Arc;
  15. use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
  16. use tokio::net::{TcpStream, UdpSocket};
  17. use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
  18. use tokio::time::{self, Duration};
  19. use tracing::{debug, error, info, instrument, Instrument, Span};
  20. #[cfg(feature = "tls")]
  21. use crate::transport::TlsTransport;
  22. use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
  23. // The entrypoint of running a client
  24. pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
  25. let config = match &config.client {
  26. Some(v) => v,
  27. None => {
  28. return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
  29. }
  30. };
  31. match config.transport.transport_type {
  32. TransportType::Tcp => {
  33. let mut client = Client::<TcpTransport>::from(config).await?;
  34. client.run(shutdown_rx).await
  35. }
  36. TransportType::Tls => {
  37. #[cfg(feature = "tls")]
  38. {
  39. let mut client = Client::<TlsTransport>::from(config).await?;
  40. client.run(shutdown_rx).await
  41. }
  42. #[cfg(not(feature = "tls"))]
  43. crate::helper::feature_not_compile("tls")
  44. }
  45. TransportType::Noise => {
  46. let mut client = Client::<NoiseTransport>::from(config).await?;
  47. client.run(shutdown_rx).await
  48. }
  49. }
  50. }
  51. type ServiceDigest = protocol::Digest;
  52. type Nonce = protocol::Digest;
  53. // Holds the state of a client
  54. struct Client<'a, T: Transport> {
  55. config: &'a ClientConfig,
  56. service_handles: HashMap<String, ControlChannelHandle>,
  57. transport: Arc<T>,
  58. }
  59. impl<'a, T: 'static + Transport> Client<'a, T> {
  60. // Create a Client from `[client]` config block
  61. async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
  62. Ok(Client {
  63. config,
  64. service_handles: HashMap::new(),
  65. transport: Arc::new(
  66. T::new(&config.transport)
  67. .await
  68. .with_context(|| "Failed to create the transport")?,
  69. ),
  70. })
  71. }
  72. // The entrypoint of Client
  73. async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
  74. for (name, config) in &self.config.services {
  75. // Create a control channel for each service defined
  76. let handle = ControlChannelHandle::new(
  77. (*config).clone(),
  78. self.config.remote_addr.clone(),
  79. self.transport.clone(),
  80. );
  81. self.service_handles.insert(name.clone(), handle);
  82. }
  83. // TODO: Maybe wait for a config change signal for hot reloading
  84. // Wait for the shutdown signal
  85. loop {
  86. tokio::select! {
  87. val = shutdown_rx.recv() => {
  88. match val {
  89. Ok(_) => {}
  90. Err(err) => {
  91. error!("Unable to listen for shutdown signal: {}", err);
  92. }
  93. }
  94. break;
  95. },
  96. }
  97. }
  98. // Shutdown all services
  99. for (_, handle) in self.service_handles.drain() {
  100. handle.shutdown();
  101. }
  102. Ok(())
  103. }
  104. }
  105. struct RunDataChannelArgs<T: Transport> {
  106. session_key: Nonce,
  107. remote_addr: String,
  108. local_addr: String,
  109. connector: Arc<T>,
  110. }
  111. async fn do_data_channel_handshake<T: Transport>(
  112. args: Arc<RunDataChannelArgs<T>>,
  113. ) -> Result<T::Stream> {
  114. // Retry at least every 100ms, at most for 10 seconds
  115. let backoff = ExponentialBackoff {
  116. max_interval: Duration::from_millis(100),
  117. max_elapsed_time: Some(Duration::from_secs(10)),
  118. ..Default::default()
  119. };
  120. // Connect to remote_addr
  121. let mut conn: T::Stream = backoff::future::retry(backoff, || async {
  122. Ok(args
  123. .connector
  124. .connect(&args.remote_addr)
  125. .await
  126. .with_context(|| "Failed to connect to remote_addr")?)
  127. })
  128. .await?;
  129. // Send nonce
  130. let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
  131. let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
  132. conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
  133. Ok(conn)
  134. }
  135. async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
  136. // Do the handshake
  137. let mut conn = do_data_channel_handshake(args.clone()).await?;
  138. // Forward
  139. match read_data_cmd(&mut conn).await? {
  140. DataChannelCmd::StartForwardTcp => {
  141. run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
  142. }
  143. DataChannelCmd::StartForwardUdp => {
  144. run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
  145. }
  146. }
  147. Ok(())
  148. }
  149. // Simply copying back and forth for TCP
  150. #[instrument(skip(conn))]
  151. async fn run_data_channel_for_tcp<T: Transport>(
  152. mut conn: T::Stream,
  153. local_addr: &str,
  154. ) -> Result<()> {
  155. debug!("New data channel starts forwarding");
  156. let mut local = TcpStream::connect(local_addr)
  157. .await
  158. .with_context(|| "Failed to conenct to local_addr")?;
  159. let _ = copy_bidirectional(&mut conn, &mut local).await;
  160. Ok(())
  161. }
  162. // Things get a little tricker when it gets to UDP because it's connectionless.
  163. // A UdpPortMap must be maintained for recent seen incoming address, giving them
  164. // each a local port, which is associated with a socket. So just the sender
  165. // to the socket will work fine for the map's value.
  166. type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
  167. #[instrument(skip(conn))]
  168. async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
  169. debug!("New data channel starts forwarding");
  170. let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
  171. // The channel stores UdpTraffic that needs to be sent to the server
  172. let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
  173. // FIXME: https://github.com/tokio-rs/tls/issues/40
  174. // Maybe this is our concern
  175. let (mut rd, mut wr) = io::split(conn);
  176. // Keep sending items from the outbound channel to the server
  177. tokio::spawn(async move {
  178. while let Some(t) = outbound_rx.recv().await {
  179. debug!("outbound {:?}", t);
  180. if t.write(&mut wr).await.is_err() {
  181. break;
  182. }
  183. }
  184. });
  185. loop {
  186. // Read a packet from the server
  187. let packet = UdpTraffic::read(&mut rd).await?;
  188. let m = port_map.read().await;
  189. if m.get(&packet.from).is_none() {
  190. // This packet is from a address we don't see for a while,
  191. // which is not in the UdpPortMap.
  192. // So set up a mapping (and a forwarder) for it
  193. // Drop the reader lock
  194. drop(m);
  195. // Grab the writer lock
  196. // This is the only thread that will try to grab the writer lock
  197. // So no need to worry about some other thread has already set up
  198. // the mapping between the gap of dropping the reader lock and
  199. // grabbing the writer lock
  200. let mut m = port_map.write().await;
  201. match udp_connect(local_addr).await {
  202. Ok(s) => {
  203. let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
  204. m.insert(packet.from, inbound_tx);
  205. tokio::spawn(run_udp_forwarder(
  206. s,
  207. inbound_rx,
  208. outbound_tx.clone(),
  209. packet.from,
  210. port_map.clone(),
  211. ));
  212. }
  213. Err(e) => {
  214. error!("{:?}", e);
  215. }
  216. }
  217. }
  218. // Now there should be a udp forwarder that can receive the packet
  219. let m = port_map.read().await;
  220. if let Some(tx) = m.get(&packet.from) {
  221. let _ = tx.send(packet.data).await;
  222. }
  223. }
  224. }
  225. // Run a UdpSocket for the visitor `from`
  226. async fn run_udp_forwarder(
  227. s: UdpSocket,
  228. mut inbound_rx: mpsc::Receiver<Bytes>,
  229. outbount_tx: mpsc::Sender<UdpTraffic>,
  230. from: SocketAddr,
  231. port_map: UdpPortMap,
  232. ) -> Result<()> {
  233. let mut buf = BytesMut::new();
  234. buf.resize(UDP_BUFFER_SIZE, 0);
  235. loop {
  236. tokio::select! {
  237. // Receive from the server
  238. data = inbound_rx.recv() => {
  239. if let Some(data) = data {
  240. s.send(&data).await?;
  241. } else {
  242. break;
  243. }
  244. },
  245. // Receive from the service
  246. val = s.recv(&mut buf) => {
  247. let len = match val {
  248. Ok(v) => v,
  249. Err(_) => {break;}
  250. };
  251. let t = UdpTraffic{
  252. from,
  253. data: Bytes::copy_from_slice(&buf[..len])
  254. };
  255. outbount_tx.send(t).await?;
  256. },
  257. // No traffic for the duration of UDP_TIMEOUT, clean up the state
  258. _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
  259. break;
  260. }
  261. }
  262. }
  263. let mut port_map = port_map.write().await;
  264. port_map.remove(&from);
  265. Ok(())
  266. }
  267. // Control channel, using T as the transport layer
  268. struct ControlChannel<T: Transport> {
  269. digest: ServiceDigest, // SHA256 of the service name
  270. service: ClientServiceConfig, // `[client.services.foo]` config block
  271. shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
  272. remote_addr: String, // `client.remote_addr`
  273. transport: Arc<T>, // Wrapper around the transport layer
  274. }
  275. // Handle of a control channel
  276. // Dropping it will also drop the actual control channel
  277. struct ControlChannelHandle {
  278. shutdown_tx: oneshot::Sender<u8>,
  279. }
  280. impl<T: 'static + Transport> ControlChannel<T> {
  281. #[instrument(skip_all)]
  282. async fn run(&mut self) -> Result<()> {
  283. let mut conn = self
  284. .transport
  285. .connect(&self.remote_addr)
  286. .await
  287. .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
  288. // Send hello
  289. let hello_send =
  290. Hello::ControlChannelHello(CURRENT_PROTO_VRESION, self.digest[..].try_into().unwrap());
  291. conn.write_all(&bincode::serialize(&hello_send).unwrap())
  292. .await?;
  293. // Read hello))
  294. let nonce = match read_hello(&mut conn)
  295. .await
  296. .with_context(|| "Failed to read hello from the server")?
  297. {
  298. ControlChannelHello(_, d) => d,
  299. _ => {
  300. bail!("Unexpected type of hello");
  301. }
  302. };
  303. // Send auth
  304. let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
  305. concat.extend_from_slice(&nonce);
  306. let session_key = protocol::digest(&concat);
  307. let auth = Auth(session_key);
  308. conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
  309. // Read ack
  310. match read_ack(&mut conn).await? {
  311. Ack::Ok => {}
  312. v => {
  313. return Err(anyhow!("{}", v))
  314. .with_context(|| format!("Authentication failed: {}", self.service.name));
  315. }
  316. }
  317. // Channel ready
  318. info!("Control channel established");
  319. let remote_addr = self.remote_addr.clone();
  320. let local_addr = self.service.local_addr.clone();
  321. let data_ch_args = Arc::new(RunDataChannelArgs {
  322. session_key,
  323. remote_addr,
  324. local_addr,
  325. connector: self.transport.clone(),
  326. });
  327. loop {
  328. tokio::select! {
  329. val = read_control_cmd(&mut conn) => {
  330. let val = val?;
  331. debug!( "Received {:?}", val);
  332. match val {
  333. ControlChannelCmd::CreateDataChannel => {
  334. let args = data_ch_args.clone();
  335. tokio::spawn(async move {
  336. if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
  337. error!("{:?}", e);
  338. }
  339. }.instrument(Span::current()));
  340. }
  341. }
  342. },
  343. _ = &mut self.shutdown_rx => {
  344. info!( "Shutting down gracefully...");
  345. break;
  346. }
  347. }
  348. }
  349. Ok(())
  350. }
  351. }
  352. impl ControlChannelHandle {
  353. #[instrument(skip_all, fields(service = %service.name))]
  354. fn new<T: 'static + Transport>(
  355. service: ClientServiceConfig,
  356. remote_addr: String,
  357. transport: Arc<T>,
  358. ) -> ControlChannelHandle {
  359. let digest = protocol::digest(service.name.as_bytes());
  360. let (shutdown_tx, shutdown_rx) = oneshot::channel();
  361. let mut s = ControlChannel {
  362. digest,
  363. service,
  364. shutdown_rx,
  365. remote_addr,
  366. transport,
  367. };
  368. tokio::spawn(
  369. async move {
  370. while let Err(err) = s
  371. .run()
  372. .await
  373. .with_context(|| "Failed to run the control channel")
  374. {
  375. let duration = Duration::from_secs(1);
  376. error!("{:?}\n\nRetry in {:?}...", err, duration);
  377. time::sleep(duration).await;
  378. }
  379. }
  380. .instrument(Span::current()),
  381. );
  382. ControlChannelHandle { shutdown_tx }
  383. }
  384. fn shutdown(self) {
  385. // A send failure shows that the actor has already shutdown.
  386. let _ = self.shutdown_tx.send(0u8);
  387. }
  388. }