| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286 |
- use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
- use crate::protocol::Hello::{self, *};
- use crate::protocol::{
- self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
- DataChannelCmd, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
- };
- use crate::transport::{TcpTransport, Transport};
- use anyhow::{anyhow, bail, Context, Result};
- use backoff::ExponentialBackoff;
- use std::collections::HashMap;
- use std::sync::Arc;
- use tokio::io::{copy_bidirectional, AsyncWriteExt};
- use tokio::net::TcpStream;
- use tokio::sync::{broadcast, oneshot};
- use tokio::time::{self, Duration};
- use tracing::{debug, error, info, instrument, Instrument, Span};
- #[cfg(feature = "tls")]
- use crate::transport::TlsTransport;
- // The entrypoint of running a client
- pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
- let config = match &config.client {
- Some(v) => v,
- None => {
- return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
- }
- };
- match config.transport.transport_type {
- TransportType::Tcp => {
- let mut client = Client::<TcpTransport>::from(config).await?;
- client.run(shutdown_rx).await
- }
- TransportType::Tls => {
- #[cfg(feature = "tls")]
- {
- let mut client = Client::<TlsTransport>::from(config).await?;
- client.run(shutdown_rx).await
- }
- #[cfg(not(feature = "tls"))]
- crate::helper::feature_not_compile("tls")
- }
- }
- }
- type ServiceDigest = protocol::Digest;
- type Nonce = protocol::Digest;
- // Holds the state of a client
- struct Client<'a, T: Transport> {
- config: &'a ClientConfig,
- service_handles: HashMap<String, ControlChannelHandle>,
- transport: Arc<T>,
- }
- impl<'a, T: 'static + Transport> Client<'a, T> {
- // Create a Client from `[client]` config block
- async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
- Ok(Client {
- config,
- service_handles: HashMap::new(),
- transport: Arc::new(
- *T::new(&config.transport)
- .await
- .with_context(|| "Failed to create the transport")?,
- ),
- })
- }
- // The entrypoint of Client
- async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
- for (name, config) in &self.config.services {
- // Create a control channel for each service defined
- let handle = ControlChannelHandle::new(
- (*config).clone(),
- self.config.remote_addr.clone(),
- self.transport.clone(),
- );
- self.service_handles.insert(name.clone(), handle);
- }
- // TODO: Maybe wait for a config change signal for hot reloading
- // Wait for the shutdown signal
- loop {
- tokio::select! {
- val = shutdown_rx.recv() => {
- match val {
- Ok(_) => {}
- Err(err) => {
- error!("Unable to listen for shutdown signal: {}", err);
- }
- }
- break;
- },
- }
- }
- // Shutdown all services
- for (_, handle) in self.service_handles.drain() {
- handle.shutdown();
- }
- Ok(())
- }
- }
- struct RunDataChannelArgs<T: Transport> {
- session_key: Nonce,
- remote_addr: String,
- local_addr: String,
- connector: Arc<T>,
- }
- async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
- // Retry at least every 100ms, at most for 10 seconds
- let backoff = ExponentialBackoff {
- max_interval: Duration::from_millis(100),
- max_elapsed_time: Some(Duration::from_secs(10)),
- ..Default::default()
- };
- // Connect to remote_addr
- let mut conn: T::Stream = backoff::future::retry(backoff, || async {
- Ok(args
- .connector
- .connect(&args.remote_addr)
- .await
- .with_context(|| "Failed to connect to remote_addr")?)
- })
- .await?;
- // Send nonce
- let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
- let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
- conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
- // Forward
- match read_data_cmd(&mut conn).await? {
- DataChannelCmd::StartForward => {
- let mut local = TcpStream::connect(&args.local_addr)
- .await
- .with_context(|| "Failed to conenct to local_addr")?;
- let _ = copy_bidirectional(&mut conn, &mut local).await;
- }
- }
- Ok(())
- }
- // Control channel, using T as the transport layer
- struct ControlChannel<T: Transport> {
- digest: ServiceDigest, // SHA256 of the service name
- service: ClientServiceConfig, // `[client.services.foo]` config block
- shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
- remote_addr: String, // `client.remote_addr`
- transport: Arc<T>, // Wrapper around the transport layer
- }
- // Handle of a control channel
- // Dropping it will also drop the actual control channel
- struct ControlChannelHandle {
- shutdown_tx: oneshot::Sender<u8>,
- }
- impl<T: 'static + Transport> ControlChannel<T> {
- #[instrument(skip(self), fields(service=%self.service.name))]
- async fn run(&mut self) -> Result<()> {
- let mut conn = self
- .transport
- .connect(&self.remote_addr)
- .await
- .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
- // Send hello
- let hello_send =
- Hello::ControlChannelHello(CURRENT_PROTO_VRESION, self.digest[..].try_into().unwrap());
- conn.write_all(&bincode::serialize(&hello_send).unwrap())
- .await?;
- // Read hello))
- let nonce = match read_hello(&mut conn)
- .await
- .with_context(|| "Failed to read hello from the server")?
- {
- ControlChannelHello(_, d) => d,
- _ => {
- bail!("Unexpected type of hello");
- }
- };
- // Send auth
- let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
- concat.extend_from_slice(&nonce);
- let session_key = protocol::digest(&concat);
- let auth = Auth(session_key);
- conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
- // Read ack
- match read_ack(&mut conn).await? {
- Ack::Ok => {}
- v => {
- return Err(anyhow!("{}", v))
- .with_context(|| format!("Authentication failed: {}", self.service.name));
- }
- }
- // Channel ready
- info!("Control channel established");
- let remote_addr = self.remote_addr.clone();
- let local_addr = self.service.local_addr.clone();
- let data_ch_args = Arc::new(RunDataChannelArgs {
- session_key,
- remote_addr,
- local_addr,
- connector: self.transport.clone(),
- });
- loop {
- tokio::select! {
- val = read_control_cmd(&mut conn) => {
- let val = val?;
- debug!( "Received {:?}", val);
- match val {
- ControlChannelCmd::CreateDataChannel => {
- let args = data_ch_args.clone();
- tokio::spawn(async move {
- if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
- error!("{:?}", e);
- }
- }.instrument(Span::current()));
- }
- }
- },
- _ = &mut self.shutdown_rx => {
- info!( "Shutting down gracefully...");
- break;
- }
- }
- }
- Ok(())
- }
- }
- impl ControlChannelHandle {
- #[instrument(skip_all, fields(service = %service.name))]
- fn new<T: 'static + Transport>(
- service: ClientServiceConfig,
- remote_addr: String,
- transport: Arc<T>,
- ) -> ControlChannelHandle {
- let digest = protocol::digest(service.name.as_bytes());
- let (shutdown_tx, shutdown_rx) = oneshot::channel();
- let mut s = ControlChannel {
- digest,
- service,
- shutdown_rx,
- remote_addr,
- transport,
- };
- tokio::spawn(
- async move {
- while let Err(err) = s
- .run()
- .await
- .with_context(|| "Failed to run the control channel")
- {
- let duration = Duration::from_secs(1);
- error!("{:?}\n\nRetry in {:?}...", err, duration);
- time::sleep(duration).await;
- }
- }
- .instrument(Span::current()),
- );
- ControlChannelHandle { shutdown_tx }
- }
- fn shutdown(self) {
- // A send failure shows that the actor has already shutdown.
- let _ = self.shutdown_tx.send(0u8);
- }
- }
|