client.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
  2. use crate::config_watcher::ServiceChange;
  3. use crate::helper::udp_connect;
  4. use crate::protocol::Hello::{self, *};
  5. use crate::protocol::{
  6. self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
  7. DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
  8. };
  9. use crate::transport::{TcpTransport, Transport};
  10. use anyhow::{anyhow, bail, Context, Result};
  11. use backoff::ExponentialBackoff;
  12. use bytes::{Bytes, BytesMut};
  13. use std::collections::HashMap;
  14. use std::net::SocketAddr;
  15. use std::sync::Arc;
  16. use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
  17. use tokio::net::{TcpStream, UdpSocket};
  18. use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
  19. use tokio::time::{self, Duration};
  20. use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
  21. #[cfg(feature = "noise")]
  22. use crate::transport::NoiseTransport;
  23. #[cfg(feature = "tls")]
  24. use crate::transport::TlsTransport;
  25. use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
  26. // The entrypoint of running a client
  27. pub async fn run_client(
  28. config: &Config,
  29. shutdown_rx: broadcast::Receiver<bool>,
  30. service_rx: mpsc::Receiver<ServiceChange>,
  31. ) -> Result<()> {
  32. let config = match &config.client {
  33. Some(v) => v,
  34. None => {
  35. return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
  36. }
  37. };
  38. match config.transport.transport_type {
  39. TransportType::Tcp => {
  40. let mut client = Client::<TcpTransport>::from(config).await?;
  41. client.run(shutdown_rx, service_rx).await
  42. }
  43. TransportType::Tls => {
  44. #[cfg(feature = "tls")]
  45. {
  46. let mut client = Client::<TlsTransport>::from(config).await?;
  47. client.run(shutdown_rx, service_rx).await
  48. }
  49. #[cfg(not(feature = "tls"))]
  50. crate::helper::feature_not_compile("tls")
  51. }
  52. TransportType::Noise => {
  53. #[cfg(feature = "noise")]
  54. {
  55. let mut client = Client::<NoiseTransport>::from(config).await?;
  56. client.run(shutdown_rx, service_rx).await
  57. }
  58. #[cfg(not(feature = "noise"))]
  59. crate::helper::feature_not_compile("noise")
  60. }
  61. }
  62. }
  63. type ServiceDigest = protocol::Digest;
  64. type Nonce = protocol::Digest;
  65. // Holds the state of a client
  66. struct Client<'a, T: Transport> {
  67. config: &'a ClientConfig,
  68. service_handles: HashMap<String, ControlChannelHandle>,
  69. transport: Arc<T>,
  70. }
  71. impl<'a, T: 'static + Transport> Client<'a, T> {
  72. // Create a Client from `[client]` config block
  73. async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
  74. Ok(Client {
  75. config,
  76. service_handles: HashMap::new(),
  77. transport: Arc::new(
  78. T::new(&config.transport)
  79. .await
  80. .with_context(|| "Failed to create the transport")?,
  81. ),
  82. })
  83. }
  84. // The entrypoint of Client
  85. async fn run(
  86. &mut self,
  87. mut shutdown_rx: broadcast::Receiver<bool>,
  88. mut service_rx: mpsc::Receiver<ServiceChange>,
  89. ) -> Result<()> {
  90. for (name, config) in &self.config.services {
  91. // Create a control channel for each service defined
  92. let handle = ControlChannelHandle::new(
  93. (*config).clone(),
  94. self.config.remote_addr.clone(),
  95. self.transport.clone(),
  96. );
  97. self.service_handles.insert(name.clone(), handle);
  98. }
  99. // Wait for the shutdown signal
  100. loop {
  101. tokio::select! {
  102. val = shutdown_rx.recv() => {
  103. match val {
  104. Ok(_) => {}
  105. Err(err) => {
  106. error!("Unable to listen for shutdown signal: {}", err);
  107. }
  108. }
  109. break;
  110. },
  111. e = service_rx.recv() => {
  112. if let Some(e) = e {
  113. match e {
  114. ServiceChange::ClientAdd(s)=> {
  115. let name = s.name.clone();
  116. let handle = ControlChannelHandle::new(
  117. s,
  118. self.config.remote_addr.clone(),
  119. self.transport.clone(),
  120. );
  121. let _ = self.service_handles.insert(name, handle);
  122. },
  123. ServiceChange::ClientDelete(s)=> {
  124. let _ = self.service_handles.remove(&s);
  125. },
  126. _ => ()
  127. }
  128. }
  129. }
  130. }
  131. }
  132. // Shutdown all services
  133. for (_, handle) in self.service_handles.drain() {
  134. handle.shutdown();
  135. }
  136. Ok(())
  137. }
  138. }
  139. struct RunDataChannelArgs<T: Transport> {
  140. session_key: Nonce,
  141. remote_addr: String,
  142. local_addr: String,
  143. connector: Arc<T>,
  144. }
  145. async fn do_data_channel_handshake<T: Transport>(
  146. args: Arc<RunDataChannelArgs<T>>,
  147. ) -> Result<T::Stream> {
  148. // Retry at least every 100ms, at most for 10 seconds
  149. let backoff = ExponentialBackoff {
  150. max_interval: Duration::from_millis(100),
  151. max_elapsed_time: Some(Duration::from_secs(10)),
  152. ..Default::default()
  153. };
  154. // Connect to remote_addr
  155. let mut conn: T::Stream = backoff::future::retry(backoff, || async {
  156. Ok(args
  157. .connector
  158. .connect(&args.remote_addr)
  159. .await
  160. .with_context(|| "Failed to connect to remote_addr")?)
  161. })
  162. .await?;
  163. // Send nonce
  164. let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
  165. let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
  166. conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
  167. Ok(conn)
  168. }
  169. async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
  170. // Do the handshake
  171. let mut conn = do_data_channel_handshake(args.clone()).await?;
  172. // Forward
  173. match read_data_cmd(&mut conn).await? {
  174. DataChannelCmd::StartForwardTcp => {
  175. run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
  176. }
  177. DataChannelCmd::StartForwardUdp => {
  178. run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
  179. }
  180. }
  181. Ok(())
  182. }
  183. // Simply copying back and forth for TCP
  184. #[instrument(skip(conn))]
  185. async fn run_data_channel_for_tcp<T: Transport>(
  186. mut conn: T::Stream,
  187. local_addr: &str,
  188. ) -> Result<()> {
  189. debug!("New data channel starts forwarding");
  190. let mut local = TcpStream::connect(local_addr)
  191. .await
  192. .with_context(|| "Failed to conenct to local_addr")?;
  193. let _ = copy_bidirectional(&mut conn, &mut local).await;
  194. Ok(())
  195. }
  196. // Things get a little tricker when it gets to UDP because it's connectionless.
  197. // A UdpPortMap must be maintained for recent seen incoming address, giving them
  198. // each a local port, which is associated with a socket. So just the sender
  199. // to the socket will work fine for the map's value.
  200. type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
  201. #[instrument(skip(conn))]
  202. async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
  203. debug!("New data channel starts forwarding");
  204. let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
  205. // The channel stores UdpTraffic that needs to be sent to the server
  206. let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
  207. // FIXME: https://github.com/tokio-rs/tls/issues/40
  208. // Maybe this is our concern
  209. let (mut rd, mut wr) = io::split(conn);
  210. // Keep sending items from the outbound channel to the server
  211. tokio::spawn(async move {
  212. while let Some(t) = outbound_rx.recv().await {
  213. trace!("outbound {:?}", t);
  214. if let Err(e) = t
  215. .write(&mut wr)
  216. .await
  217. .with_context(|| "Failed to forward UDP traffic to the server")
  218. {
  219. debug!("{:?}", e);
  220. break;
  221. }
  222. }
  223. });
  224. loop {
  225. // Read a packet from the server
  226. let hdr_len = rd.read_u16().await?;
  227. let packet = UdpTraffic::read(&mut rd, hdr_len)
  228. .await
  229. .with_context(|| "Failed to read UDPTraffic from the server")?;
  230. let m = port_map.read().await;
  231. if m.get(&packet.from).is_none() {
  232. // This packet is from a address we don't see for a while,
  233. // which is not in the UdpPortMap.
  234. // So set up a mapping (and a forwarder) for it
  235. // Drop the reader lock
  236. drop(m);
  237. // Grab the writer lock
  238. // This is the only thread that will try to grab the writer lock
  239. // So no need to worry about some other thread has already set up
  240. // the mapping between the gap of dropping the reader lock and
  241. // grabbing the writer lock
  242. let mut m = port_map.write().await;
  243. match udp_connect(local_addr).await {
  244. Ok(s) => {
  245. let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
  246. m.insert(packet.from, inbound_tx);
  247. tokio::spawn(run_udp_forwarder(
  248. s,
  249. inbound_rx,
  250. outbound_tx.clone(),
  251. packet.from,
  252. port_map.clone(),
  253. ));
  254. }
  255. Err(e) => {
  256. error!("{:?}", e);
  257. }
  258. }
  259. }
  260. // Now there should be a udp forwarder that can receive the packet
  261. let m = port_map.read().await;
  262. if let Some(tx) = m.get(&packet.from) {
  263. let _ = tx.send(packet.data).await;
  264. }
  265. }
  266. }
  267. // Run a UdpSocket for the visitor `from`
  268. #[instrument(skip_all, fields(from))]
  269. async fn run_udp_forwarder(
  270. s: UdpSocket,
  271. mut inbound_rx: mpsc::Receiver<Bytes>,
  272. outbount_tx: mpsc::Sender<UdpTraffic>,
  273. from: SocketAddr,
  274. port_map: UdpPortMap,
  275. ) -> Result<()> {
  276. debug!("Forwarder created");
  277. let mut buf = BytesMut::new();
  278. buf.resize(UDP_BUFFER_SIZE, 0);
  279. loop {
  280. tokio::select! {
  281. // Receive from the server
  282. data = inbound_rx.recv() => {
  283. if let Some(data) = data {
  284. s.send(&data).await?;
  285. } else {
  286. break;
  287. }
  288. },
  289. // Receive from the service
  290. val = s.recv(&mut buf) => {
  291. let len = match val {
  292. Ok(v) => v,
  293. Err(_) => {break;}
  294. };
  295. let t = UdpTraffic{
  296. from,
  297. data: Bytes::copy_from_slice(&buf[..len])
  298. };
  299. outbount_tx.send(t).await?;
  300. },
  301. // No traffic for the duration of UDP_TIMEOUT, clean up the state
  302. _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
  303. break;
  304. }
  305. }
  306. }
  307. let mut port_map = port_map.write().await;
  308. port_map.remove(&from);
  309. debug!("Forwarder dropped");
  310. Ok(())
  311. }
  312. // Control channel, using T as the transport layer
  313. struct ControlChannel<T: Transport> {
  314. digest: ServiceDigest, // SHA256 of the service name
  315. service: ClientServiceConfig, // `[client.services.foo]` config block
  316. shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
  317. remote_addr: String, // `client.remote_addr`
  318. transport: Arc<T>, // Wrapper around the transport layer
  319. }
  320. // Handle of a control channel
  321. // Dropping it will also drop the actual control channel
  322. struct ControlChannelHandle {
  323. shutdown_tx: oneshot::Sender<u8>,
  324. }
  325. impl<T: 'static + Transport> ControlChannel<T> {
  326. #[instrument(skip_all)]
  327. async fn run(&mut self) -> Result<()> {
  328. let mut conn = self
  329. .transport
  330. .connect(&self.remote_addr)
  331. .await
  332. .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
  333. // Send hello
  334. debug!("Sending hello");
  335. let hello_send =
  336. Hello::ControlChannelHello(CURRENT_PROTO_VRESION, self.digest[..].try_into().unwrap());
  337. conn.write_all(&bincode::serialize(&hello_send).unwrap())
  338. .await?;
  339. // Read hello
  340. debug!("Reading hello");
  341. let nonce = match read_hello(&mut conn)
  342. .await
  343. .with_context(|| "Failed to read hello from the server")?
  344. {
  345. ControlChannelHello(_, d) => d,
  346. _ => {
  347. bail!("Unexpected type of hello");
  348. }
  349. };
  350. // Send auth
  351. debug!("Sending auth");
  352. let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
  353. concat.extend_from_slice(&nonce);
  354. let session_key = protocol::digest(&concat);
  355. let auth = Auth(session_key);
  356. conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
  357. // Read ack
  358. debug!("Reading ack");
  359. match read_ack(&mut conn).await? {
  360. Ack::Ok => {}
  361. v => {
  362. return Err(anyhow!("{}", v))
  363. .with_context(|| format!("Authentication failed: {}", self.service.name));
  364. }
  365. }
  366. // Channel ready
  367. info!("Control channel established");
  368. let remote_addr = self.remote_addr.clone();
  369. let local_addr = self.service.local_addr.clone();
  370. let data_ch_args = Arc::new(RunDataChannelArgs {
  371. session_key,
  372. remote_addr,
  373. local_addr,
  374. connector: self.transport.clone(),
  375. });
  376. loop {
  377. tokio::select! {
  378. val = read_control_cmd(&mut conn) => {
  379. let val = val?;
  380. debug!( "Received {:?}", val);
  381. match val {
  382. ControlChannelCmd::CreateDataChannel => {
  383. let args = data_ch_args.clone();
  384. tokio::spawn(async move {
  385. if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
  386. error!("{:?}", e);
  387. }
  388. }.instrument(Span::current()));
  389. }
  390. }
  391. },
  392. _ = &mut self.shutdown_rx => {
  393. info!( "Control channel shutting down...");
  394. break;
  395. }
  396. }
  397. }
  398. Ok(())
  399. }
  400. }
  401. impl ControlChannelHandle {
  402. #[instrument(skip_all, fields(service = %service.name))]
  403. fn new<T: 'static + Transport>(
  404. service: ClientServiceConfig,
  405. remote_addr: String,
  406. transport: Arc<T>,
  407. ) -> ControlChannelHandle {
  408. let digest = protocol::digest(service.name.as_bytes());
  409. info!("Starting {}", hex::encode(digest));
  410. let (shutdown_tx, shutdown_rx) = oneshot::channel();
  411. let mut s = ControlChannel {
  412. digest,
  413. service,
  414. shutdown_rx,
  415. remote_addr,
  416. transport,
  417. };
  418. tokio::spawn(
  419. async move {
  420. while let Err(err) = s
  421. .run()
  422. .await
  423. .with_context(|| "Failed to run the control channel")
  424. {
  425. if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) {
  426. break;
  427. }
  428. let duration = Duration::from_secs(1);
  429. error!("{:?}\n\nRetry in {:?}...", err, duration);
  430. time::sleep(duration).await;
  431. }
  432. }
  433. .instrument(Span::current()),
  434. );
  435. ControlChannelHandle { shutdown_tx }
  436. }
  437. fn shutdown(self) {
  438. // A send failure shows that the actor has already shutdown.
  439. let _ = self.shutdown_tx.send(0u8);
  440. }
  441. }