client.rs 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
  2. use crate::protocol::Hello::{self, *};
  3. use crate::protocol::{
  4. self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
  5. DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
  6. };
  7. use crate::transport::{TcpTransport, Transport};
  8. use anyhow::{anyhow, bail, Context, Result};
  9. use backoff::ExponentialBackoff;
  10. use std::collections::HashMap;
  11. use std::sync::Arc;
  12. use tokio::io::{copy_bidirectional, AsyncWriteExt};
  13. use tokio::net::TcpStream;
  14. use tokio::sync::{broadcast, oneshot};
  15. use tokio::time::{self, Duration};
  16. use tracing::{debug, error, info, instrument, Instrument, Span};
  17. #[cfg(feature = "tls")]
  18. use crate::transport::TlsTransport;
  19. // The entrypoint of running a client
  20. pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
  21. let config = match &config.client {
  22. Some(v) => v,
  23. None => {
  24. return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
  25. }
  26. };
  27. match config.transport.transport_type {
  28. TransportType::Tcp => {
  29. let mut client = Client::<TcpTransport>::from(config).await?;
  30. client.run(shutdown_rx).await
  31. }
  32. TransportType::Tls => {
  33. #[cfg(feature = "tls")]
  34. {
  35. let mut client = Client::<TlsTransport>::from(config).await?;
  36. client.run(shutdown_rx).await
  37. }
  38. #[cfg(not(feature = "tls"))]
  39. crate::helper::feature_not_compile("tls")
  40. }
  41. }
  42. }
  43. type ServiceDigest = protocol::Digest;
  44. type Nonce = protocol::Digest;
  45. // Holds the state of a client
  46. struct Client<'a, T: Transport> {
  47. config: &'a ClientConfig,
  48. service_handles: HashMap<String, ControlChannelHandle>,
  49. transport: Arc<T>,
  50. }
  51. impl<'a, T: 'static + Transport> Client<'a, T> {
  52. // Create a Client from `[client]` config block
  53. async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
  54. Ok(Client {
  55. config,
  56. service_handles: HashMap::new(),
  57. transport: Arc::new(
  58. *T::new(&config.transport)
  59. .await
  60. .with_context(|| "Failed to create the transport")?,
  61. ),
  62. })
  63. }
  64. // The entrypoint of Client
  65. async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
  66. for (name, config) in &self.config.services {
  67. // Create a control channel for each service defined
  68. let handle = ControlChannelHandle::new(
  69. (*config).clone(),
  70. self.config.remote_addr.clone(),
  71. self.transport.clone(),
  72. );
  73. self.service_handles.insert(name.clone(), handle);
  74. }
  75. // TODO: Maybe wait for a config change signal for hot reloading
  76. // Wait for the shutdown signal
  77. loop {
  78. tokio::select! {
  79. val = shutdown_rx.recv() => {
  80. match val {
  81. Ok(_) => {}
  82. Err(err) => {
  83. error!("Unable to listen for shutdown signal: {}", err);
  84. }
  85. }
  86. break;
  87. },
  88. }
  89. }
  90. // Shutdown all services
  91. for (_, handle) in self.service_handles.drain() {
  92. handle.shutdown();
  93. }
  94. Ok(())
  95. }
  96. }
  97. struct RunDataChannelArgs<T: Transport> {
  98. session_key: Nonce,
  99. remote_addr: String,
  100. local_addr: String,
  101. connector: Arc<T>,
  102. }
  103. async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
  104. // Retry at least every 100ms, at most for 10 seconds
  105. let backoff = ExponentialBackoff {
  106. max_interval: Duration::from_millis(100),
  107. max_elapsed_time: Some(Duration::from_secs(10)),
  108. ..Default::default()
  109. };
  110. // Connect to remote_addr
  111. let mut conn: T::Stream = backoff::future::retry(backoff, || async {
  112. Ok(args
  113. .connector
  114. .connect(&args.remote_addr)
  115. .await
  116. .with_context(|| "Failed to connect to remote_addr")?)
  117. })
  118. .await?;
  119. // Send nonce
  120. let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
  121. let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
  122. conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
  123. // Forward
  124. match read_data_cmd(&mut conn).await? {
  125. DataChannelCmd::StartForward => {
  126. let mut local = TcpStream::connect(&args.local_addr)
  127. .await
  128. .with_context(|| "Failed to conenct to local_addr")?;
  129. let _ = copy_bidirectional(&mut conn, &mut local).await;
  130. }
  131. }
  132. Ok(())
  133. }
  134. // Control channel, using T as the transport layer
  135. struct ControlChannel<T: Transport> {
  136. digest: ServiceDigest, // SHA256 of the service name
  137. service: ClientServiceConfig, // `[client.services.foo]` config block
  138. shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
  139. remote_addr: String, // `client.remote_addr`
  140. transport: Arc<T>, // Wrapper around the transport layer
  141. }
  142. // Handle of a control channel
  143. // Dropping it will also drop the actual control channel
  144. struct ControlChannelHandle {
  145. shutdown_tx: oneshot::Sender<u8>,
  146. }
  147. impl<T: 'static + Transport> ControlChannel<T> {
  148. #[instrument(skip(self), fields(service=%self.service.name))]
  149. async fn run(&mut self) -> Result<()> {
  150. let mut conn = self
  151. .transport
  152. .connect(&self.remote_addr)
  153. .await
  154. .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
  155. // Send hello
  156. let hello_send =
  157. Hello::ControlChannelHello(CURRENT_PROTO_VRESION, self.digest[..].try_into().unwrap());
  158. conn.write_all(&bincode::serialize(&hello_send).unwrap())
  159. .await?;
  160. // Read hello))
  161. let nonce = match read_hello(&mut conn)
  162. .await
  163. .with_context(|| "Failed to read hello from the server")?
  164. {
  165. ControlChannelHello(_, d) => d,
  166. _ => {
  167. bail!("Unexpected type of hello");
  168. }
  169. };
  170. // Send auth
  171. let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
  172. concat.extend_from_slice(&nonce);
  173. let session_key = protocol::digest(&concat);
  174. let auth = Auth(session_key);
  175. conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
  176. // Read ack
  177. match read_ack(&mut conn).await? {
  178. Ack::Ok => {}
  179. v => {
  180. return Err(anyhow!("{}", v))
  181. .with_context(|| format!("Authentication failed: {}", self.service.name));
  182. }
  183. }
  184. // Channel ready
  185. info!("Control channel established");
  186. let remote_addr = self.remote_addr.clone();
  187. let local_addr = self.service.local_addr.clone();
  188. let data_ch_args = Arc::new(RunDataChannelArgs {
  189. session_key,
  190. remote_addr,
  191. local_addr,
  192. connector: self.transport.clone(),
  193. });
  194. loop {
  195. tokio::select! {
  196. val = read_control_cmd(&mut conn) => {
  197. let val = val?;
  198. debug!( "Received {:?}", val);
  199. match val {
  200. ControlChannelCmd::CreateDataChannel => {
  201. let args = data_ch_args.clone();
  202. tokio::spawn(async move {
  203. if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
  204. error!("{:?}", e);
  205. }
  206. }.instrument(Span::current()));
  207. }
  208. }
  209. },
  210. _ = &mut self.shutdown_rx => {
  211. info!( "Shutting down gracefully...");
  212. break;
  213. }
  214. }
  215. }
  216. Ok(())
  217. }
  218. }
  219. impl ControlChannelHandle {
  220. #[instrument(skip_all, fields(service = %service.name))]
  221. fn new<T: 'static + Transport>(
  222. service: ClientServiceConfig,
  223. remote_addr: String,
  224. transport: Arc<T>,
  225. ) -> ControlChannelHandle {
  226. let digest = protocol::digest(service.name.as_bytes());
  227. let (shutdown_tx, shutdown_rx) = oneshot::channel();
  228. let mut s = ControlChannel {
  229. digest,
  230. service,
  231. shutdown_rx,
  232. remote_addr,
  233. transport,
  234. };
  235. tokio::spawn(
  236. async move {
  237. while let Err(err) = s
  238. .run()
  239. .await
  240. .with_context(|| "Failed to run the control channel")
  241. {
  242. let duration = Duration::from_secs(1);
  243. error!("{:?}\n\nRetry in {:?}...", err, duration);
  244. time::sleep(duration).await;
  245. }
  246. }
  247. .instrument(Span::current()),
  248. );
  249. ControlChannelHandle { shutdown_tx }
  250. }
  251. fn shutdown(self) {
  252. // A send failure shows that the actor has already shutdown.
  253. let _ = self.shutdown_tx.send(0u8);
  254. }
  255. }