| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444 |
- use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
- use crate::helper::udp_connect;
- use crate::protocol::Hello::{self, *};
- use crate::protocol::{
- self, read_ack, read_control_cmd, read_data_cmd, read_hello, Ack, Auth, ControlChannelCmd,
- DataChannelCmd, UdpTraffic, CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
- };
- use crate::transport::{NoiseTransport, TcpTransport, Transport};
- use anyhow::{anyhow, bail, Context, Result};
- use backoff::ExponentialBackoff;
- use bytes::{Bytes, BytesMut};
- use std::collections::HashMap;
- use std::net::SocketAddr;
- use std::sync::Arc;
- use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
- use tokio::net::{TcpStream, UdpSocket};
- use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
- use tokio::time::{self, Duration};
- use tracing::{debug, error, info, instrument, Instrument, Span};
- #[cfg(feature = "tls")]
- use crate::transport::TlsTransport;
- use crate::constants::{UDP_BUFFER_SIZE, UDP_SENDQ_SIZE, UDP_TIMEOUT};
- // The entrypoint of running a client
- pub async fn run_client(config: &Config, shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
- let config = match &config.client {
- Some(v) => v,
- None => {
- return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
- }
- };
- match config.transport.transport_type {
- TransportType::Tcp => {
- let mut client = Client::<TcpTransport>::from(config).await?;
- client.run(shutdown_rx).await
- }
- TransportType::Tls => {
- #[cfg(feature = "tls")]
- {
- let mut client = Client::<TlsTransport>::from(config).await?;
- client.run(shutdown_rx).await
- }
- #[cfg(not(feature = "tls"))]
- crate::helper::feature_not_compile("tls")
- }
- TransportType::Noise => {
- let mut client = Client::<NoiseTransport>::from(config).await?;
- client.run(shutdown_rx).await
- }
- }
- }
- type ServiceDigest = protocol::Digest;
- type Nonce = protocol::Digest;
- // Holds the state of a client
- struct Client<'a, T: Transport> {
- config: &'a ClientConfig,
- service_handles: HashMap<String, ControlChannelHandle>,
- transport: Arc<T>,
- }
- impl<'a, T: 'static + Transport> Client<'a, T> {
- // Create a Client from `[client]` config block
- async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
- Ok(Client {
- config,
- service_handles: HashMap::new(),
- transport: Arc::new(
- T::new(&config.transport)
- .await
- .with_context(|| "Failed to create the transport")?,
- ),
- })
- }
- // The entrypoint of Client
- async fn run(&mut self, mut shutdown_rx: broadcast::Receiver<bool>) -> Result<()> {
- for (name, config) in &self.config.services {
- // Create a control channel for each service defined
- let handle = ControlChannelHandle::new(
- (*config).clone(),
- self.config.remote_addr.clone(),
- self.transport.clone(),
- );
- self.service_handles.insert(name.clone(), handle);
- }
- // TODO: Maybe wait for a config change signal for hot reloading
- // Wait for the shutdown signal
- loop {
- tokio::select! {
- val = shutdown_rx.recv() => {
- match val {
- Ok(_) => {}
- Err(err) => {
- error!("Unable to listen for shutdown signal: {}", err);
- }
- }
- break;
- },
- }
- }
- // Shutdown all services
- for (_, handle) in self.service_handles.drain() {
- handle.shutdown();
- }
- Ok(())
- }
- }
- struct RunDataChannelArgs<T: Transport> {
- session_key: Nonce,
- remote_addr: String,
- local_addr: String,
- connector: Arc<T>,
- }
- async fn do_data_channel_handshake<T: Transport>(
- args: Arc<RunDataChannelArgs<T>>,
- ) -> Result<T::Stream> {
- // Retry at least every 100ms, at most for 10 seconds
- let backoff = ExponentialBackoff {
- max_interval: Duration::from_millis(100),
- max_elapsed_time: Some(Duration::from_secs(10)),
- ..Default::default()
- };
- // Connect to remote_addr
- let mut conn: T::Stream = backoff::future::retry(backoff, || async {
- Ok(args
- .connector
- .connect(&args.remote_addr)
- .await
- .with_context(|| "Failed to connect to remote_addr")?)
- })
- .await?;
- // Send nonce
- let v: &[u8; HASH_WIDTH_IN_BYTES] = args.session_key[..].try_into().unwrap();
- let hello = Hello::DataChannelHello(CURRENT_PROTO_VRESION, v.to_owned());
- conn.write_all(&bincode::serialize(&hello).unwrap()).await?;
- Ok(conn)
- }
- async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
- // Do the handshake
- let mut conn = do_data_channel_handshake(args.clone()).await?;
- // Forward
- match read_data_cmd(&mut conn).await? {
- DataChannelCmd::StartForwardTcp => {
- run_data_channel_for_tcp::<T>(conn, &args.local_addr).await?;
- }
- DataChannelCmd::StartForwardUdp => {
- run_data_channel_for_udp::<T>(conn, &args.local_addr).await?;
- }
- }
- Ok(())
- }
- // Simply copying back and forth for TCP
- #[instrument(skip(conn))]
- async fn run_data_channel_for_tcp<T: Transport>(
- mut conn: T::Stream,
- local_addr: &str,
- ) -> Result<()> {
- debug!("New data channel starts forwarding");
- let mut local = TcpStream::connect(local_addr)
- .await
- .with_context(|| "Failed to conenct to local_addr")?;
- let _ = copy_bidirectional(&mut conn, &mut local).await;
- Ok(())
- }
- // Things get a little tricker when it gets to UDP because it's connectionless.
- // A UdpPortMap must be maintained for recent seen incoming address, giving them
- // each a local port, which is associated with a socket. So just the sender
- // to the socket will work fine for the map's value.
- type UdpPortMap = Arc<RwLock<HashMap<SocketAddr, mpsc::Sender<Bytes>>>>;
- #[instrument(skip(conn))]
- async fn run_data_channel_for_udp<T: Transport>(conn: T::Stream, local_addr: &str) -> Result<()> {
- debug!("New data channel starts forwarding");
- let port_map: UdpPortMap = Arc::new(RwLock::new(HashMap::new()));
- // The channel stores UdpTraffic that needs to be sent to the server
- let (outbound_tx, mut outbound_rx) = mpsc::channel::<UdpTraffic>(UDP_SENDQ_SIZE);
- // FIXME: https://github.com/tokio-rs/tls/issues/40
- // Maybe this is our concern
- let (mut rd, mut wr) = io::split(conn);
- // Keep sending items from the outbound channel to the server
- tokio::spawn(async move {
- while let Some(t) = outbound_rx.recv().await {
- debug!("outbound {:?}", t);
- if t.write(&mut wr).await.is_err() {
- break;
- }
- }
- });
- loop {
- // Read a packet from the server
- let packet = UdpTraffic::read(&mut rd).await?;
- let m = port_map.read().await;
- if m.get(&packet.from).is_none() {
- // This packet is from a address we don't see for a while,
- // which is not in the UdpPortMap.
- // So set up a mapping (and a forwarder) for it
- // Drop the reader lock
- drop(m);
- // Grab the writer lock
- // This is the only thread that will try to grab the writer lock
- // So no need to worry about some other thread has already set up
- // the mapping between the gap of dropping the reader lock and
- // grabbing the writer lock
- let mut m = port_map.write().await;
- match udp_connect(local_addr).await {
- Ok(s) => {
- let (inbound_tx, inbound_rx) = mpsc::channel(UDP_SENDQ_SIZE);
- m.insert(packet.from, inbound_tx);
- tokio::spawn(run_udp_forwarder(
- s,
- inbound_rx,
- outbound_tx.clone(),
- packet.from,
- port_map.clone(),
- ));
- }
- Err(e) => {
- error!("{:?}", e);
- }
- }
- }
- // Now there should be a udp forwarder that can receive the packet
- let m = port_map.read().await;
- if let Some(tx) = m.get(&packet.from) {
- let _ = tx.send(packet.data).await;
- }
- }
- }
- // Run a UdpSocket for the visitor `from`
- async fn run_udp_forwarder(
- s: UdpSocket,
- mut inbound_rx: mpsc::Receiver<Bytes>,
- outbount_tx: mpsc::Sender<UdpTraffic>,
- from: SocketAddr,
- port_map: UdpPortMap,
- ) -> Result<()> {
- let mut buf = BytesMut::new();
- buf.resize(UDP_BUFFER_SIZE, 0);
- loop {
- tokio::select! {
- // Receive from the server
- data = inbound_rx.recv() => {
- if let Some(data) = data {
- s.send(&data).await?;
- } else {
- break;
- }
- },
- // Receive from the service
- val = s.recv(&mut buf) => {
- let len = match val {
- Ok(v) => v,
- Err(_) => {break;}
- };
- let t = UdpTraffic{
- from,
- data: Bytes::copy_from_slice(&buf[..len])
- };
- outbount_tx.send(t).await?;
- },
- // No traffic for the duration of UDP_TIMEOUT, clean up the state
- _ = time::sleep(Duration::from_secs(UDP_TIMEOUT)) => {
- break;
- }
- }
- }
- let mut port_map = port_map.write().await;
- port_map.remove(&from);
- Ok(())
- }
- // Control channel, using T as the transport layer
- struct ControlChannel<T: Transport> {
- digest: ServiceDigest, // SHA256 of the service name
- service: ClientServiceConfig, // `[client.services.foo]` config block
- shutdown_rx: oneshot::Receiver<u8>, // Receives the shutdown signal
- remote_addr: String, // `client.remote_addr`
- transport: Arc<T>, // Wrapper around the transport layer
- }
- // Handle of a control channel
- // Dropping it will also drop the actual control channel
- struct ControlChannelHandle {
- shutdown_tx: oneshot::Sender<u8>,
- }
- impl<T: 'static + Transport> ControlChannel<T> {
- #[instrument(skip_all)]
- async fn run(&mut self) -> Result<()> {
- let mut conn = self
- .transport
- .connect(&self.remote_addr)
- .await
- .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
- // Send hello
- let hello_send =
- Hello::ControlChannelHello(CURRENT_PROTO_VRESION, self.digest[..].try_into().unwrap());
- conn.write_all(&bincode::serialize(&hello_send).unwrap())
- .await?;
- // Read hello))
- let nonce = match read_hello(&mut conn)
- .await
- .with_context(|| "Failed to read hello from the server")?
- {
- ControlChannelHello(_, d) => d,
- _ => {
- bail!("Unexpected type of hello");
- }
- };
- // Send auth
- let mut concat = Vec::from(self.service.token.as_ref().unwrap().as_bytes());
- concat.extend_from_slice(&nonce);
- let session_key = protocol::digest(&concat);
- let auth = Auth(session_key);
- conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
- // Read ack
- match read_ack(&mut conn).await? {
- Ack::Ok => {}
- v => {
- return Err(anyhow!("{}", v))
- .with_context(|| format!("Authentication failed: {}", self.service.name));
- }
- }
- // Channel ready
- info!("Control channel established");
- let remote_addr = self.remote_addr.clone();
- let local_addr = self.service.local_addr.clone();
- let data_ch_args = Arc::new(RunDataChannelArgs {
- session_key,
- remote_addr,
- local_addr,
- connector: self.transport.clone(),
- });
- loop {
- tokio::select! {
- val = read_control_cmd(&mut conn) => {
- let val = val?;
- debug!( "Received {:?}", val);
- match val {
- ControlChannelCmd::CreateDataChannel => {
- let args = data_ch_args.clone();
- tokio::spawn(async move {
- if let Err(e) = run_data_channel(args).await.with_context(|| "Failed to run the data channel") {
- error!("{:?}", e);
- }
- }.instrument(Span::current()));
- }
- }
- },
- _ = &mut self.shutdown_rx => {
- info!( "Shutting down gracefully...");
- break;
- }
- }
- }
- Ok(())
- }
- }
- impl ControlChannelHandle {
- #[instrument(skip_all, fields(service = %service.name))]
- fn new<T: 'static + Transport>(
- service: ClientServiceConfig,
- remote_addr: String,
- transport: Arc<T>,
- ) -> ControlChannelHandle {
- let digest = protocol::digest(service.name.as_bytes());
- let (shutdown_tx, shutdown_rx) = oneshot::channel();
- let mut s = ControlChannel {
- digest,
- service,
- shutdown_rx,
- remote_addr,
- transport,
- };
- tokio::spawn(
- async move {
- while let Err(err) = s
- .run()
- .await
- .with_context(|| "Failed to run the control channel")
- {
- let duration = Duration::from_secs(1);
- error!("{:?}\n\nRetry in {:?}...", err, duration);
- time::sleep(duration).await;
- }
- }
- .instrument(Span::current()),
- );
- ControlChannelHandle { shutdown_tx }
- }
- fn shutdown(self) {
- // A send failure shows that the actor has already shutdown.
- let _ = self.shutdown_tx.send(0u8);
- }
- }
|