use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType}; use crate::helper::udp_connect; use crate::protocol::Hello::{self, *}; use crate::protocol::{ self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd, DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES, }; use crate::transport::{NoiseTransport, TcpTransport, Transport}; use anyhow::{anyhow, bail, Context, Result}; use backoff::ExponentialBackoff; use bytes::{Bytes, BytesMut}; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::io::{self, copy_bidirectional, AsyncWriteExt}; use tokio::net::{TcpStream, UdpSocket}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use tokio::time::{self, Duration}; use tracing::{debug, error, info, instrument, Instrument, Span}; #[cfg(feature = "tls")] use crate::transport::TlsTransport; use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT}; // The entrypoint of running a client pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver) -> Result<()> { let config = match &config.client { Some(v) => v, None => { return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block")) } }; match config.transport.transport_type { TransportType::Tcp => { let mut client = Client::::from(config).await?; client.run(shutdown_rx).await } TransportType::Tls => { #[cfg(feature = "tls")] { let mut client = Client::::from(config).await?; client.run(shutdown_rx).await } #[cfg(not(feature = "tls"))] crate::helper::feature_not_compile("tls") } TransportType::Noise => { let mut client = Client::::from(config).await?; client.run(shutdown_rx).await } } } type ServiceDigest = protocol::Digest; type Nonce = protocol::Digest; // Holds the state of a client struct Client<'a, T: Transport> { config: &'a ClientConfig, service_handles: HashMap, transport: Arc, } impl<'a, T: 'static + Transport> Client<'a, T> { // Create a Client from `[client]` config block async fn from(config: &'a ClientConfig) -> Result> { Ok(Client { config, service_handles: HashMap::new(), transport: Arc::new( T::new(&config.transport) .await .with_context(|| "Failed to create the transport")?, ), }) } // The entrypoint of Client async fn run(&mut self, mut shutdown_rx: broadcast::Receiver) -> Result<()> { for (name, config) in &self.config.services { // Create a control channel for each service defined let handle = ControlChannelHandle::new( (*config).clone(), self.config.remote_addr.clone(), self.transport.clone(), ); self.service_handles.insert(name.clone(), handle); } // TODO: Maybe wait for a config change signal for hot reloading // Wait for the shutdown signal loop { tokio::select! { val = shutdown_rx.recv() => { match val { Ok(_) => {} Err(err) => { error!("Unable to listen for shutdown signal: {}", err); } } break; }, } } // Shutdown all services for (_, handle) in self.service_handles.drain() { handle.shutdown(); } Ok(()) } } struct RunDataChannelArgs { session_key: Nonce, remote_addr: String, local_addr: String, connector: Arc, } async fn do_data_channel_handshake( args: Arc>, ) -> Result { // Retry at least every 100ms, at most for 10 seconds let backoff = ExponentialBackoff { max_interval: Duration::from_millis(100), max_elapsed_time: Some(Duration::from_secs(10)), ..Default::default() }; // Connect to remote_addr let mut conn: T::Stream = backoff::future::retry(backoff, || async { Ok(args .connector .connect(&args.remote_addr) .await .with_context(|| "Failed to connect to remote_addr")?) }) .await?; // Send nonce let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap(); let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned()); conn.write_all(&bincode::serialize(&hello).unwrap()).await?; Ok(conn) } async fn run_data_channel(args: Arc>) -> Result<()> { // Do the handshake let mut conn = do_data_channel_handshake(args.clone()).await?; // Forward match read_data_cmd(&mut conn).await? { DataChannelCmd::StartForwardTcp => { run_data_channel_for_tcp::(conn, &args.local_addr).await?; } DataChannelCmd::StartForwardUdp => { run_data_channel_for_udp::(conn, &args.local_addr).await?; } } Ok(()) } // Simply copying back and forth for TCP #[instrument(skip(conn))] async fn run_data_channel_for_tcp( mut conn: T::Stream, local_addr: &str, ) -> Result<()> { debug!("New data channel starts forwarding"); let mut local = TcpStream::connect(local_addr) .await .with_context(|| "Failed to conenct to local_addr")?; let _ = copy_bidirectional(&mut conn, &mut local).await; Ok(()) } // Things get a little tricker when it gets to UDP because it's connectionless. // A UdpPortMap must be maintained for recent seen incoming address, giving them // each a local port, which is associated with a socket. So just the sender // to the socket will work fine for the map's value. type UdpPortMap = Arc>>>; #[instrument(skip(conn))] async fn run_data_channel_for_udp(conn: T::Stream, local_addr: &str) -> Result<()> { debug!("New data channel starts forwarding"); let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new())); // The channel stores UdpTraffic that needs to be sent to the server let (outbound_tx, mut outbound_rx) = mpsc::channel::(UDP_SENDQ_SIZE); // FIXME: https://github.com/tokio-rs/tls/issues/40 // Maybe this is our concern let (mut rd, mut wr) = io::split(conn); // Keep sending items from the outbound channel to the server tokio::spawn(async move { while let Some(t) = outbound_rx.recv().await { debug!("outbound {:?}", t); if t.write(&mut wr).await.is_err() { break; } } }); loop { // Read a packet from the server let packet = UdpTraffic::read(&mut rd).await?; let m = port_map.read().await; if m.get(&packet.from).is_none() { // This packet is from a address we don't see for a while, // which is not in the UdpPortMap. // So set up a mapping (and a forwarder) for it // Drop the reader lock drop(m); // Grab the writer lock // This is the only thread that will try to grab the writer lock // So no need to worry about some other thread has already set up // the mapping between the gap of dropping the reader lock and // grabbing the writer lock let mut m = port_map.write().await; match udp_connect(local_addr).await { Ok(s) => { let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE); m.insert(packet.from, inbound_tx); tokio::spawn(run_udp_forwarder( s, inbound_rx, outbound_tx.clone(), packet.from, port_map.clone(), )); } Err(e) => { error!("{:?}", e); } } } // Now there should be a udp forwarder that can receive the packet let m = port_map.read().await; if let Some(tx) = m.get(&packet.from) { let _ = tx.send(packet.data).await; } } } // Run a UdpSocket for the visitor `from` async fn run_udp_forwarder( s: UdpSocket, mut inbound_rx: mpsc::Receiver, outbount_tx: mpsc::Sender, from: SocketAddr, port_map: UdpPortMap, ) -> Result<()> { let mut buf = BytesMut::new(); buf.resize(UDP_BUFFER_SIZE, 0); loop { tokio::select! { // Receive from the server data = inbound_rx.recv() => { if let Some(data) = data { s.send(&data).await?; } else { break; } }, // Receive from the service val = s.recv(&mut buf) => { let len = match val { Ok(v) => v, Err(_) => {break;} }; let t = UdpTraffic{ from, data: Bytes::copy_from_slice(&buf[..len]) }; outbount_tx.send(t).await?; }, // No traffic for the duration of UDP_TIMEOUT, clean up the state _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => { break; } } } let mut port_map = port_map.write().await; port_map.remove(&from); Ok(()) } // Control channel, using T as the transport layer struct ControlChannel { digest: ServiceDigest, // SHA256 of the service name service: ClientServiceConfig, // `[client.services.foo]` config block shutdown_rx: oneshot::Receiver, // Receives the shutdown signal remote_addr: String, // `client.remote_addr` transport: Arc, // Wrapper around the transport layer } // Handle of a control channel // Dropping it will also drop the actual control channel struct ControlChannelHandle { shutdown_tx: oneshot::Sender, } impl ControlChannel { #[instrument(skip_all)] async fn run(&mut self) -> Result<()> { let mut conn = self .transport .connect(&self.remote_addr) .await .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?; // Send hello let hello_send = Hello::ControlChannelHello(CURRENT_PROTO_VRESION, self.digest[..].try_into().unwrap()); conn.write_all(&bincode::serialize(&hello_send).unwrap()) .await?; // Read hello)) let nonce = match read_hello(&mut conn) .await .with_context(|| "Failed to read hello from the server")? { ControlChannelHello(_, d) => d, _ => { bail!("Unexpected type of hello"); } }; // Send auth let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes()); concat.extend_from_slice(&nonce); let session_key = protocol::digest(&concat); let auth = Auth(session_key); conn.write_all(&bincode::serialize(&auth).unwrap()).await?; // Read ack match read_ack(&mut conn).await? { Ack::Ok => {} v => { return Err(anyhow!("{}", v)) .with_context(|| format!("Authentication failed: {}", self.service.name)); } } // Channel ready info!("Control channel established"); let remote_addr = self.remote_addr.clone(); let local_addr = self.service.local_addr.clone(); let data_ch_args = Arc::new(RunDataChannelArgs { session_key, remote_addr, local_addr, connector: self.transport.clone(), }); loop { tokio::select! { val = read_control_cmd(&mut conn) => { let val = val?; debug!( "Received {:?}", val); match val { ControlChannelCmd::CreateDataChannel => { let args = data_ch_args.clone(); tokio::spawn(async move { if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") { error!("{:?}", e); } }.instrument(Span::current())); } } }, _ = &mut self.shutdown_rx => { info!( "Shutting down gracefully..."); break; } } } Ok(()) } } impl ControlChannelHandle { #[instrument(skip_all, fields(service = %service.name))] fn new( service: ClientServiceConfig, remote_addr: String, transport: Arc, ) -> ControlChannelHandle { let digest = protocol::digest(service.name.as_bytes()); let (shutdown_tx, shutdown_rx) = oneshot::channel(); let mut s = ControlChannel { digest, service, shutdown_rx, remote_addr, transport, }; tokio::spawn( async move { while let Err(err) = s .run() .await .with_context(|| "Failed to run the control channel") { let duration = Duration::from_secs(1); error!("{:?}\n\nRetry in {:?}...", err, duration); time::sleep(duration).await; } } .instrument(Span::current()), ); ControlChannelHandle { shutdown_tx } } fn shutdown(self) { // A send failure shows that the actor has already shutdown. let _ = self.shutdown_tx.send(0u8); } }