client.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
  2. use crate::config_watcher::ServiceChange;
  3. use crate::helper::udp_connect;
  4. use crate::protocol::Hello::{self, *};
  5. use crate::protocol::{
  6. self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
  7. DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES,
  8. };
  9. use crate::transport::{TcpTransport, Transport};
  10. use anyhow::{anyhow, bail, Context, Result};
  11. use backoff::ExponentialBackoff;
  12. use bytes::{Bytes, BytesMut};
  13. use std::collections::HashMap;
  14. use std::net::SocketAddr;
  15. use std::sync::Arc;
  16. use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
  17. use tokio::net::{TcpStream, UdpSocket};
  18. use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
  19. use tokio::time::{self, Duration};
  20. use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
  21. #[cfg(feature = "noise")]
  22. use crate::transport::NoiseTransport;
  23. #[cfg(feature = "tls")]
  24. use crate::transport::TlsTransport;
  25. use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
  26. // The entrypoint of running a client
  27. pub async fn run_client(
  28. config: &Config,
  29. shutdown_rx: broadcast::Receiver<bool>,
  30. service_rx: mpsc::Receiver<ServiceChange>,
  31. ) -> Result<()> {
  32. let config = match &config.client {
  33. Some(v) => v,
  34. None => {
  35. return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
  36. }
  37. };
  38. match config.transport.transport_type {
  39. TransportType::Tcp => {
  40. let mut client = Client::<TcpTransport>::from(config).await?;
  41. client.run(shutdown_rx, service_rx).await
  42. }
  43. TransportType::Tls => {
  44. #[cfg(feature = "tls")]
  45. {
  46. let mut client = Client::<TlsTransport>::from(config).await?;
  47. client.run(shutdown_rx, service_rx).await
  48. }
  49. #[cfg(not(feature = "tls"))]
  50. crate::helper::feature_not_compile("tls")
  51. }
  52. TransportType::Noise => {
  53. #[cfg(feature = "noise")]
  54. {
  55. let mut client = Client::<NoiseTransport>::from(config).await?;
  56. client.run(shutdown_rx, service_rx).await
  57. }
  58. #[cfg(not(feature = "noise"))]
  59. crate::helper::feature_not_compile("noise")
  60. }
  61. }
  62. }
/// Digest type used to identify a service; in this file it is computed from the
/// service name via `protocol::digest`.
type ServiceDigest = protocol::Digest;
/// Digest-sized value used in this file as the session key that is sent in the
/// data channel hello.
type Nonce = protocol::Digest;

/// Holds the state of a client.
struct Client<'a, T: Transport> {
    /// The `[client]` config block this client was created from.
    config: &'a ClientConfig,
    /// One control channel handle per running service, keyed by service name.
    service_handles: HashMap<String, ControlChannelHandle>,
    /// The transport shared by every control channel and data channel.
    transport: Arc<T>,
}
  71. impl<'a, T: 'static + Transport> Client<'a, T> {
  72. // Create a Client from `[client]` config block
  73. async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
  74. Ok(Client {
  75. config,
  76. service_handles: HashMap::new(),
  77. transport: Arc::new(
  78. T::new(&config.transport)
  79. .await
  80. .with_context(|| "Failed to create the transport")?,
  81. ),
  82. })
  83. }
  84. // The entrypoint of Client
  85. async fn run(
  86. &mut self,
  87. mut shutdown_rx: broadcast::Receiver<bool>,
  88. mut service_rx: mpsc::Receiver<ServiceChange>,
  89. ) -> Result<()> {
  90. for (name, config) in &self.config.services {
  91. // Create a control channel for each service defined
  92. let handle = ControlChannelHandle::new(
  93. (*config).clone(),
  94. self.config.remote_addr.clone(),
  95. self.transport.clone(),
  96. );
  97. self.service_handles.insert(name.clone(), handle);
  98. }
  99. // Wait for the shutdown signal
  100. loop {
  101. tokio::select! {
  102. val = shutdown_rx.recv() => {
  103. match val {
  104. Ok(_) => {}
  105. Err(err) => {
  106. error!("Unable to listen for shutdown signal: {}", err);
  107. }
  108. }
  109. break;
  110. },
  111. e = service_rx.recv() => {
  112. if let Some(e) = e {
  113. match e {
  114. ServiceChange::ClientAdd(s)=> {
  115. let name = s.name.clone();
  116. let handle = ControlChannelHandle::new(
  117. s,
  118. self.config.remote_addr.clone(),
  119. self.transport.clone(),
  120. );
  121. let _ = self.service_handles.insert(name, handle);
  122. },
  123. ServiceChange::ClientDelete(s)=> {
  124. let _ = self.service_handles.remove(&s);
  125. },
  126. _ => ()
  127. }
  128. }
  129. }
  130. }
  131. }
  132. // Shutdown all services
  133. for (_, handle) in self.service_handles.drain() {
  134. handle.shutdown();
  135. }
  136. Ok(())
  137. }
  138. }
/// Everything a data channel needs in order to be established and run.
/// Shared behind an `Arc` between all data channels of one control channel.
struct RunDataChannelArgs<T: Transport> {
    /// Value sent to the server in the data channel hello.
    session_key: Nonce,
    /// Address of the server the data channel connects to.
    remote_addr: String,
    /// Address of the local service that traffic is forwarded to.
    local_addr: String,
    /// Transport used to open the connection to `remote_addr`.
    connector: Arc<T>,
}
  145. async fn do_data_channel_handshake<T: Transport>(
  146. args: Arc<RunDataChannelArgs<T>>,
  147. ) -> Result<T::Stream> {
  148. // Retry at least every 100ms, at most for 10 seconds
  149. let backoff = ExponentialBackoff {
  150. max_interval: Duration::from_millis(100),
  151. max_elapsed_time: Some(Duration::from_secs(10)),
  152. ..Default::default()
  153. };
  154. // FIXME: Respect control channel shutdown here
  155. // Connect to remote_addr
  156. let mut conn: T::Stream = backoff::future::retry_notify(
  157. backoff,
  158. || async {
  159. Ok(args
  160. .connector
  161. .connect(&args.remote_addr)
  162. .await
  163. .with_context(|| "Failed to connect to remote_addr")?)
  164. },
  165. |e, duration| {
  166. warn!("{:?}. Retry in {:?}", e, duration);
  167. },
  168. )
  169. .await?;
  170. // Send nonce
  171. let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
  172. let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned());
  173. conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
  174. conn.flush().await?;
  175. Ok(conn)
  176. }
  177. async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
  178. // Do the handshake
  179. let mut conn = do_data_channel_handshake(args.clone()).await?;
  180. // Forward
  181. match read_data_cmd(&mut conn).await? {
  182. DataChannelCmd::StartForwardTcp => {
  183. run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
  184. }
  185. DataChannelCmd::StartForwardUdp => {
  186. run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
  187. }
  188. }
  189. Ok(())
  190. }
  191. // Simply copying back and forth for TCP
  192. #[instrument(skip(conn))]
  193. async fn run_data_channel_for_tcp<T: Transport>(
  194. mut conn: T::Stream,
  195. local_addr: &str,
  196. ) -> Result<()> {
  197. debug!("New data channel starts forwarding");
  198. let mut local = TcpStream::connect(local_addr)
  199. .await
  200. .with_context(|| "Failed to connect to local_addr")?;
  201. let _ = copy_bidirectional(&mut conn, &mut local).await;
  202. Ok(())
  203. }
// Things get a little trickier with UDP because it is connectionless.
// A UdpPortMap must be maintained for recently seen incoming addresses, giving
// each of them a local socket of its own. Only the sender half of the channel
// feeding that socket's forwarder needs to be stored as the map's value.
type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
  209. #[instrument(skip(conn))]
  210. async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
  211. debug!("New data channel starts forwarding");
  212. let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
  213. // The channel stores UdpTraffic that needs to be sent to the server
  214. let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
  215. // FIXME: https://github.com/tokio-rs/tls/issues/40
  216. // Maybe this is our concern
  217. let (mut rd, mut wr) = io::split(conn);
  218. // Keep sending items from the outbound channel to the server
  219. tokio::spawn(async move {
  220. while let Some(t) = outbound_rx.recv().await {
  221. trace!("outbound {:?}", t);
  222. if let Err(e) = t
  223. .write(&mut wr)
  224. .await
  225. .with_context(|| "Failed to forward UDP traffic to the server")
  226. {
  227. debug!("{:?}", e);
  228. break;
  229. }
  230. }
  231. });
  232. loop {
  233. // Read a packet from the server
  234. let hdr_len = rd.read_u8().await?;
  235. let packet = UdpTraffic::read(&mut rd, hdr_len)
  236. .await
  237. .with_context(|| "Failed to read UDPTraffic from the server")?;
  238. let m = port_map.read().await;
  239. if m.get(&packet.from).is_none() {
  240. // This packet is from a address we don't see for a while,
  241. // which is not in the UdpPortMap.
  242. // So set up a mapping (and a forwarder) for it
  243. // Drop the reader lock
  244. drop(m);
  245. // Grab the writer lock
  246. // This is the only thread that will try to grab the writer lock
  247. // So no need to worry about some other thread has already set up
  248. // the mapping between the gap of dropping the reader lock and
  249. // grabbing the writer lock
  250. let mut m = port_map.write().await;
  251. match udp_connect(local_addr).await {
  252. Ok(s) => {
  253. let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
  254. m.insert(packet.from, inbound_tx);
  255. tokio::spawn(run_udp_forwarder(
  256. s,
  257. inbound_rx,
  258. outbound_tx.clone(),
  259. packet.from,
  260. port_map.clone(),
  261. ));
  262. }
  263. Err(e) => {
  264. error!("{:?}", e);
  265. }
  266. }
  267. }
  268. // Now there should be a udp forwarder that can receive the packet
  269. let m = port_map.read().await;
  270. if let Some(tx) = m.get(&packet.from) {
  271. let _ = tx.send(packet.data).await;
  272. }
  273. }
  274. }
  275. // Run a UdpSocket for the visitor `from`
  276. #[instrument(skip_all, fields(from))]
  277. async fn run_udp_forwarder(
  278. s: UdpSocket,
  279. mut inbound_rx: mpsc::Receiver<Bytes>,
  280. outbount_tx: mpsc::Sender<UdpTraffic>,
  281. from: SocketAddr,
  282. port_map: UdpPortMap,
  283. ) -> Result<()> {
  284. debug!("Forwarder created");
  285. let mut buf = BytesMut::new();
  286. buf.resize(UDP_BUFFER_SIZE, 0);
  287. loop {
  288. tokio::select! {
  289. // Receive from the server
  290. data = inbound_rx.recv() => {
  291. if let Some(data) = data {
  292. s.send(&data).await?;
  293. } else {
  294. break;
  295. }
  296. },
  297. // Receive from the service
  298. val = s.recv(&mut buf) => {
  299. let len = match val {
  300. Ok(v) => v,
  301. Err(_) => {break;}
  302. };
  303. let t = UdpTraffic{
  304. from,
  305. data: Bytes::copy_from_slice(&buf[..len])
  306. };
  307. outbount_tx.send(t).await?;
  308. },
  309. // No traffic for the duration of UDP_TIMEOUT, clean up the state
  310. _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
  311. break;
  312. }
  313. }
  314. }
  315. let mut port_map = port_map.write().await;
  316. port_map.remove(&from);
  317. debug!("Forwarder dropped");
  318. Ok(())
  319. }
/// Control channel, using `T` as the transport layer.
///
/// Built by `ControlChannelHandle::new`, which also spawns the task that runs it.
struct ControlChannel<T: Transport> {
    digest: ServiceDigest,              // SHA256 of the service name
    service: ClientServiceConfig,       // `[client.services.foo]` config block
    shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
    remote_addr: String,                // `client.remote_addr`
    transport: Arc<T>,                  // Wrapper around the transport layer
}

/// Handle of a control channel.
///
/// Dropping it will also drop the actual control channel: the channel's run loop
/// ends as soon as its `shutdown_rx` resolves, which happens both on an explicit
/// send and when this sender is dropped.
struct ControlChannelHandle {
    /// Sender half paired with `ControlChannel::shutdown_rx`.
    shutdown_tx: oneshot::Sender<u8>,
}
  333. impl<T: 'static + Transport> ControlChannel<T> {
  334. #[instrument(skip_all)]
  335. async fn run(&mut self) -> Result<()> {
  336. let mut conn = self
  337. .transport
  338. .connect(&self.remote_addr)
  339. .await
  340. .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
  341. // Send hello
  342. debug!("Sending hello");
  343. let hello_send =
  344. Hello::ControlChannelHello(CURRENT_PROTO_VERSION, self.digest[..].try_into().unwrap());
  345. conn.write_all(&bincode::serialize(&hello_send).unwrap())
  346. .await?;
  347. conn.flush().await?;
  348. // Read hello
  349. debug!("Reading hello");
  350. let nonce = match read_hello(&mut conn).await? {
  351. ControlChannelHello(_, d) => d,
  352. _ => {
  353. bail!("Unexpected type of hello");
  354. }
  355. };
  356. // Send auth
  357. debug!("Sending auth");
  358. let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
  359. concat.extend_from_slice(&nonce);
  360. let session_key = protocol::digest(&concat);
  361. let auth = Auth(session_key);
  362. conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
  363. conn.flush().await?;
  364. // Read ack
  365. debug!("Reading ack");
  366. match read_ack(&mut conn).await? {
  367. Ack::Ok => {}
  368. v => {
  369. return Err(anyhow!("{}", v))
  370. .with_context(|| format!("Authentication failed: {}", self.service.name));
  371. }
  372. }
  373. // Channel ready
  374. info!("Control channel established");
  375. let remote_addr = self.remote_addr.clone();
  376. let local_addr = self.service.local_addr.clone();
  377. let data_ch_args = Arc::new(RunDataChannelArgs {
  378. session_key,
  379. remote_addr,
  380. local_addr,
  381. connector: self.transport.clone(),
  382. });
  383. loop {
  384. tokio::select! {
  385. val = read_control_cmd(&mut conn) => {
  386. let val = val?;
  387. debug!( "Received {:?}", val);
  388. match val {
  389. ControlChannelCmd::CreateDataChannel => {
  390. let args = data_ch_args.clone();
  391. tokio::spawn(async move {
  392. if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
  393. error!("{:?}", e);
  394. }
  395. }.instrument(Span::current()));
  396. }
  397. }
  398. },
  399. _ = &mut self.shutdown_rx => {
  400. break;
  401. }
  402. }
  403. }
  404. info!("Control channel shutdown");
  405. Ok(())
  406. }
  407. }
  408. impl ControlChannelHandle {
  409. #[instrument(skip_all, fields(service = %service.name))]
  410. fn new<T: 'static + Transport>(
  411. service: ClientServiceConfig,
  412. remote_addr: String,
  413. transport: Arc<T>,
  414. ) -> ControlChannelHandle {
  415. let digest = protocol::digest(service.name.as_bytes());
  416. info!("Starting {}", hex::encode(digest));
  417. let (shutdown_tx, shutdown_rx) = oneshot::channel();
  418. let mut s = ControlChannel {
  419. digest,
  420. service,
  421. shutdown_rx,
  422. remote_addr,
  423. transport,
  424. };
  425. tokio::spawn(
  426. async move {
  427. while let Err(err) = s
  428. .run()
  429. .await
  430. .with_context(|| "Failed to run the control channel")
  431. {
  432. if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) {
  433. break;
  434. }
  435. let duration = Duration::from_secs(1);
  436. error!("{:?}\n\nRetry in {:?}...", err, duration);
  437. time::sleep(duration).await;
  438. }
  439. }
  440. .instrument(Span::current()),
  441. );
  442. ControlChannelHandle { shutdown_tx }
  443. }
  444. fn shutdown(self) {
  445. // A send failure shows that the actor has already shutdown.
  446. let _ = self.shutdown_tx.send(0u8);
  447. }
  448. }