client.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527
  1. use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
  2. use crate::config_watcher::ServiceChange;
  3. use crate::helper::udp_connect;
  4. use crate::protocol::Hello::{self, *};
  5. use crate::protocol::{
  6. self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
  7. DataChannelCmd, UdpTraffic, CURRENT_PROTO_VERSION, HASH_WIDTH_IN_BYTES,
  8. };
  9. use crate::transport::{SocketOpts, TcpTransport, Transport};
  10. use anyhow::{anyhow, bail, Context, Result};
  11. use backoff::backoff::Backoff;
  12. use backoff::ExponentialBackoff;
  13. use bytes::{Bytes, BytesMut};
  14. use std::collections::HashMap;
  15. use std::net::SocketAddr;
  16. use std::sync::Arc;
  17. use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
  18. use tokio::net::{TcpStream, UdpSocket};
  19. use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
  20. use tokio::time::{self, Duration, Instant};
  21. use tracing::{debug, error, info, instrument, trace, warn, Instrument, Span};
  22. #[cfg(feature = "noise")]
  23. use crate::transport::NoiseTransport;
  24. #[cfg(feature = "tls")]
  25. use crate::transport::TlsTransport;
  26. use crate::constants::{run_control_chan_backoff, UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
  27. // The entrypoint of running a client
  28. pub async fn run_client(
  29. config: &Config,
  30. shutdown_rx: broadcast::Receiver<bool>,
  31. service_rx: mpsc::Receiver<ServiceChange>,
  32. ) -> Result<()> {
  33. let config = match &config.client {
  34. Some(v) => v,
  35. None => {
  36. return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
  37. }
  38. };
  39. match config.transport.transport_type {
  40. TransportType::Tcp => {
  41. let mut client = Client::<TcpTransport>::from(config).await?;
  42. client.run(shutdown_rx, service_rx).await
  43. }
  44. TransportType::Tls => {
  45. #[cfg(feature = "tls")]
  46. {
  47. let mut client = Client::<TlsTransport>::from(config).await?;
  48. client.run(shutdown_rx, service_rx).await
  49. }
  50. #[cfg(not(feature = "tls"))]
  51. crate::helper::feature_not_compile("tls")
  52. }
  53. TransportType::Noise => {
  54. #[cfg(feature = "noise")]
  55. {
  56. let mut client = Client::<NoiseTransport>::from(config).await?;
  57. client.run(shutdown_rx, service_rx).await
  58. }
  59. #[cfg(not(feature = "noise"))]
  60. crate::helper::feature_not_compile("noise")
  61. }
  62. }
  63. }
  64. type ServiceDigest = protocol::Digest;
  65. type Nonce = protocol::Digest;
  66. // Holds the state of a client
  67. struct Client<'a, T: Transport> {
  68. config: &'a ClientConfig,
  69. service_handles: HashMap<String, ControlChannelHandle>,
  70. transport: Arc<T>,
  71. }
  72. impl<'a, T: 'static + Transport> Client<'a, T> {
  73. // Create a Client from `[client]` config block
  74. async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
  75. Ok(Client {
  76. config,
  77. service_handles: HashMap::new(),
  78. transport: Arc::new(
  79. T::new(&config.transport).with_context(|| "Failed to create the transport")?,
  80. ),
  81. })
  82. }
  83. // The entrypoint of Client
  84. async fn run(
  85. &mut self,
  86. mut shutdown_rx: broadcast::Receiver<bool>,
  87. mut service_rx: mpsc::Receiver<ServiceChange>,
  88. ) -> Result<()> {
  89. for (name, config) in &self.config.services {
  90. // Create a control channel for each service defined
  91. let handle = ControlChannelHandle::new(
  92. (*config).clone(),
  93. self.config.remote_addr.clone(),
  94. self.transport.clone(),
  95. );
  96. self.service_handles.insert(name.clone(), handle);
  97. }
  98. // Wait for the shutdown signal
  99. loop {
  100. tokio::select! {
  101. val = shutdown_rx.recv() => {
  102. match val {
  103. Ok(_) => {}
  104. Err(err) => {
  105. error!("Unable to listen for shutdown signal: {}", err);
  106. }
  107. }
  108. break;
  109. },
  110. e = service_rx.recv() => {
  111. if let Some(e) = e {
  112. match e {
  113. ServiceChange::ClientAdd(s)=> {
  114. let name = s.name.clone();
  115. let handle = ControlChannelHandle::new(
  116. s,
  117. self.config.remote_addr.clone(),
  118. self.transport.clone(),
  119. );
  120. let _ = self.service_handles.insert(name, handle);
  121. },
  122. ServiceChange::ClientDelete(s)=> {
  123. let _ = self.service_handles.remove(&s);
  124. },
  125. _ => ()
  126. }
  127. }
  128. }
  129. }
  130. }
  131. // Shutdown all services
  132. for (_, handle) in self.service_handles.drain() {
  133. handle.shutdown();
  134. }
  135. Ok(())
  136. }
  137. }
  138. struct RunDataChannelArgs<T: Transport> {
  139. session_key: Nonce,
  140. remote_addr: String,
  141. local_addr: String,
  142. connector: Arc<T>,
  143. socket_opts: SocketOpts,
  144. }
  145. async fn do_data_channel_handshake<T: Transport>(
  146. args: Arc<RunDataChannelArgs<T>>,
  147. ) -> Result<T::Stream> {
  148. // Retry at least every 100ms, at most for 10 seconds
  149. let backoff = ExponentialBackoff {
  150. max_interval: Duration::from_millis(100),
  151. max_elapsed_time: Some(Duration::from_secs(10)),
  152. ..Default::default()
  153. };
  154. // FIXME: Respect control channel shutdown here
  155. // Connect to remote_addr
  156. let mut conn: T::Stream = backoff::future::retry_notify(
  157. backoff,
  158. || async {
  159. let conn = args
  160. .connector
  161. .connect(&args.remote_addr)
  162. .await
  163. .with_context(|| format!("Failed to connect to {}", &args.remote_addr))?;
  164. T::hint(&conn, args.socket_opts);
  165. Ok(conn)
  166. },
  167. |e, duration| {
  168. warn!("{:#}. Retry in {:?}", e, duration);
  169. },
  170. )
  171. .await?;
  172. // Send nonce
  173. let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
  174. let hello = Hello::DataChannelHello(CURRENT_PROTO_VERSION, v.to_owned());
  175. conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
  176. conn.flush().await?;
  177. Ok(conn)
  178. }
  179. async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
  180. // Do the handshake
  181. let mut conn = do_data_channel_handshake(args.clone()).await?;
  182. // Forward
  183. match read_data_cmd(&mut conn).await? {
  184. DataChannelCmd::StartForwardTcp => {
  185. run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
  186. }
  187. DataChannelCmd::StartForwardUdp => {
  188. run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
  189. }
  190. }
  191. Ok(())
  192. }
  193. // Simply copying back and forth for TCP
  194. #[instrument(skip(conn))]
  195. async fn run_data_channel_for_tcp<T: Transport>(
  196. mut conn: T::Stream,
  197. local_addr: &str,
  198. ) -> Result<()> {
  199. debug!("New data channel starts forwarding");
  200. let mut local = TcpStream::connect(local_addr)
  201. .await
  202. .with_context(|| format!("Failed to connect to {}", local_addr))?;
  203. let _ = copy_bidirectional(&mut conn, &mut local).await;
  204. Ok(())
  205. }
  206. // Things get a little tricker when it gets to UDP because it's connection-less.
  207. // A UdpPortMap must be maintained for recent seen incoming address, giving them
  208. // each a local port, which is associated with a socket. So just the sender
  209. // to the socket will work fine for the map's value.
  210. type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
  211. #[instrument(skip(conn))]
  212. async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
  213. debug!("New data channel starts forwarding");
  214. let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
  215. // The channel stores UdpTraffic that needs to be sent to the server
  216. let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
  217. // FIXME: https://github.com/tokio-rs/tls/issues/40
  218. // Maybe this is our concern
  219. let (mut rd, mut wr) = io::split(conn);
  220. // Keep sending items from the outbound channel to the server
  221. tokio::spawn(async move {
  222. while let Some(t) = outbound_rx.recv().await {
  223. trace!("outbound {:?}", t);
  224. if let Err(e) = t
  225. .write(&mut wr)
  226. .await
  227. .with_context(|| "Failed to forward UDP traffic to the server")
  228. {
  229. debug!("{:?}", e);
  230. break;
  231. }
  232. }
  233. });
  234. loop {
  235. // Read a packet from the server
  236. let hdr_len = rd.read_u8().await?;
  237. let packet = UdpTraffic::read(&mut rd, hdr_len)
  238. .await
  239. .with_context(|| "Failed to read UDPTraffic from the server")?;
  240. let m = port_map.read().await;
  241. if m.get(&packet.from).is_none() {
  242. // This packet is from a address we don't see for a while,
  243. // which is not in the UdpPortMap.
  244. // So set up a mapping (and a forwarder) for it
  245. // Drop the reader lock
  246. drop(m);
  247. // Grab the writer lock
  248. // This is the only thread that will try to grab the writer lock
  249. // So no need to worry about some other thread has already set up
  250. // the mapping between the gap of dropping the reader lock and
  251. // grabbing the writer lock
  252. let mut m = port_map.write().await;
  253. match udp_connect(local_addr).await {
  254. Ok(s) => {
  255. let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
  256. m.insert(packet.from, inbound_tx);
  257. tokio::spawn(run_udp_forwarder(
  258. s,
  259. inbound_rx,
  260. outbound_tx.clone(),
  261. packet.from,
  262. port_map.clone(),
  263. ));
  264. }
  265. Err(e) => {
  266. error!("{:#}", e);
  267. }
  268. }
  269. }
  270. // Now there should be a udp forwarder that can receive the packet
  271. let m = port_map.read().await;
  272. if let Some(tx) = m.get(&packet.from) {
  273. let _ = tx.send(packet.data).await;
  274. }
  275. }
  276. }
  277. // Run a UdpSocket for the visitor `from`
  278. #[instrument(skip_all, fields(from))]
  279. async fn run_udp_forwarder(
  280. s: UdpSocket,
  281. mut inbound_rx: mpsc::Receiver<Bytes>,
  282. outbount_tx: mpsc::Sender<UdpTraffic>,
  283. from: SocketAddr,
  284. port_map: UdpPortMap,
  285. ) -> Result<()> {
  286. debug!("Forwarder created");
  287. let mut buf = BytesMut::new();
  288. buf.resize(UDP_BUFFER_SIZE, 0);
  289. loop {
  290. tokio::select! {
  291. // Receive from the server
  292. data = inbound_rx.recv() => {
  293. if let Some(data) = data {
  294. s.send(&data).await?;
  295. } else {
  296. break;
  297. }
  298. },
  299. // Receive from the service
  300. val = s.recv(&mut buf) => {
  301. let len = match val {
  302. Ok(v) => v,
  303. Err(_) => {break;}
  304. };
  305. let t = UdpTraffic{
  306. from,
  307. data: Bytes::copy_from_slice(&buf[..len])
  308. };
  309. outbount_tx.send(t).await?;
  310. },
  311. // No traffic for the duration of UDP_TIMEOUT, clean up the state
  312. _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
  313. break;
  314. }
  315. }
  316. }
  317. let mut port_map = port_map.write().await;
  318. port_map.remove(&from);
  319. debug!("Forwarder dropped");
  320. Ok(())
  321. }
  322. // Control channel, using T as the transport layer
  323. struct ControlChannel<T: Transport> {
  324. digest: ServiceDigest, // SHA256 of the service name
  325. service: ClientServiceConfig, // `[client.services.foo]` config block
  326. shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
  327. remote_addr: String, // `client.remote_addr`
  328. transport: Arc<T>, // Wrapper around the transport layer
  329. }
  330. // Handle of a control channel
  331. // Dropping it will also drop the actual control channel
  332. struct ControlChannelHandle {
  333. shutdown_tx: oneshot::Sender<u8>,
  334. }
  335. impl<T: 'static + Transport> ControlChannel<T> {
  336. #[instrument(skip_all)]
  337. async fn run(&mut self) -> Result<()> {
  338. let mut conn = self
  339. .transport
  340. .connect(&self.remote_addr)
  341. .await
  342. .with_context(|| format!("Failed to connect to {}", &self.remote_addr))?;
  343. T::hint(&conn, SocketOpts::for_control_channel());
  344. // Send hello
  345. debug!("Sending hello");
  346. let hello_send =
  347. Hello::ControlChannelHello(CURRENT_PROTO_VERSION, self.digest[..].try_into().unwrap());
  348. conn.write_all(&bincode::serialize(&hello_send).unwrap())
  349. .await?;
  350. conn.flush().await?;
  351. // Read hello
  352. debug!("Reading hello");
  353. let nonce = match read_hello(&mut conn).await? {
  354. ControlChannelHello(_, d) => d,
  355. _ => {
  356. bail!("Unexpected type of hello");
  357. }
  358. };
  359. // Send auth
  360. debug!("Sending auth");
  361. let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
  362. concat.extend_from_slice(&nonce);
  363. let session_key = protocol::digest(&concat);
  364. let auth = Auth(session_key);
  365. conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
  366. conn.flush().await?;
  367. // Read ack
  368. debug!("Reading ack");
  369. match read_ack(&mut conn).await? {
  370. Ack::Ok => {}
  371. v => {
  372. return Err(anyhow!("{}", v))
  373. .with_context(|| format!("Authentication failed: {}", self.service.name));
  374. }
  375. }
  376. // Channel ready
  377. info!("Control channel established");
  378. let remote_addr = self.remote_addr.clone();
  379. let local_addr = self.service.local_addr.clone();
  380. // Socket options for the data channel
  381. let socket_opts = SocketOpts::from_client_cfg(&self.service);
  382. let data_ch_args = Arc::new(RunDataChannelArgs {
  383. session_key,
  384. remote_addr,
  385. local_addr,
  386. connector: self.transport.clone(),
  387. socket_opts,
  388. });
  389. loop {
  390. tokio::select! {
  391. val = read_control_cmd(&mut conn) => {
  392. let val = val?;
  393. debug!( "Received {:?}", val);
  394. match val {
  395. ControlChannelCmd::CreateDataChannel => {
  396. let args = data_ch_args.clone();
  397. tokio::spawn(async move {
  398. if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
  399. warn!("{:#}", e);
  400. }
  401. }.instrument(Span::current()));
  402. }
  403. }
  404. },
  405. _ = &mut self.shutdown_rx => {
  406. break;
  407. }
  408. }
  409. }
  410. info!("Control channel shutdown");
  411. Ok(())
  412. }
  413. }
  414. impl ControlChannelHandle {
  415. #[instrument(name="handle", skip_all, fields(service = %service.name))]
  416. fn new<T: 'static + Transport>(
  417. service: ClientServiceConfig,
  418. remote_addr: String,
  419. transport: Arc<T>,
  420. ) -> ControlChannelHandle {
  421. let digest = protocol::digest(service.name.as_bytes());
  422. info!("Starting {}", hex::encode(digest));
  423. let (shutdown_tx, shutdown_rx) = oneshot::channel();
  424. let mut s = ControlChannel {
  425. digest,
  426. service,
  427. shutdown_rx,
  428. remote_addr,
  429. transport,
  430. };
  431. tokio::spawn(
  432. async move {
  433. let mut backoff = run_control_chan_backoff();
  434. let mut start = Instant::now();
  435. while let Err(err) = s
  436. .run()
  437. .await
  438. .with_context(|| "Failed to run the control channel")
  439. {
  440. if s.shutdown_rx.try_recv() != Err(oneshot::error::TryRecvError::Empty) {
  441. break;
  442. }
  443. if start.elapsed() > Duration::from_secs(3) {
  444. // The client runs for at least 3 secs and then disconnects
  445. // Retry immediately
  446. backoff.reset();
  447. error!("{:#}. Retry...", err);
  448. } else if let Some(duration) = backoff.next_backoff() {
  449. error!("{:#}. Retry in {:?}...", err, duration);
  450. time::sleep(duration).await;
  451. } else {
  452. // Should never reach
  453. panic!("{:#}. Break", err);
  454. }
  455. start = Instant::now();
  456. }
  457. }
  458. .instrument(Span::current()),
  459. );
  460. ControlChannelHandle { shutdown_tx }
  461. }
  462. fn shutdown(self) {
  463. // A send failure shows that the actor has already shutdown.
  464. let _ = self.shutdown_tx.send(0u8);
  465. }
  466. }