瀏覽代碼

feat: TLS support

Yujia Qiao 4 年之前
父節點
當前提交
dcef7f2d0f
共有 15 個文件被更改,包括 668 次插入156 次删除
  1. 173 0
      Cargo.lock
  2. 2 0
      Cargo.toml
  3. 12 1
      README.md
  4. 35 0
      example/tls/ca-cert.pem
  5. 12 0
      example/tls/client.toml
  6. 二進制
      example/tls/identity.pfx
  7. 13 0
      example/tls/server.toml
  8. 62 32
      src/client.rs
  9. 95 28
      src/config.rs
  10. 1 0
      src/lib.rs
  11. 11 10
      src/protocol.rs
  12. 90 85
      src/server.rs
  13. 25 0
      src/transport/mod.rs
  14. 41 0
      src/transport/tcp.rs
  15. 96 0
      src/transport/tls.rs

+ 173 - 0
Cargo.lock

@@ -17,6 +17,17 @@ version = "1.0.51"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8b26702f315f53b6071259e15dd9d64528213b44d61de1ec926eca7715d62203"
 
+[[package]]
+name = "async-trait"
+version = "0.1.52"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "061a7acccaa286c011ddc30970520b98fa40e00c9d644633fb26b5fc63a265e3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "atty"
 version = "0.2.14"
@@ -130,6 +141,22 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "core-foundation"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6888e10551bb93e424d8df1d07f1a8b4fceb0001a3a4b048bfc47554946f47b3"
+dependencies = [
+ "core-foundation-sys",
+ "libc",
+]
+
+[[package]]
+name = "core-foundation-sys"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc"
+
 [[package]]
 name = "fdlimit"
 version = "0.2.1"
@@ -139,6 +166,21 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "foreign-types"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
+dependencies = [
+ "foreign-types-shared",
+]
+
+[[package]]
+name = "foreign-types-shared"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
+
 [[package]]
 name = "futures-core"
 version = "0.3.18"
@@ -287,6 +329,24 @@ dependencies = [
  "winapi",
 ]
 
+[[package]]
+name = "native-tls"
+version = "0.2.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48ba9f7719b5a0f42f338907614285fb5fd70e53858141f69898a1fb7203b24d"
+dependencies = [
+ "lazy_static",
+ "libc",
+ "log",
+ "openssl",
+ "openssl-probe",
+ "openssl-sys",
+ "schannel",
+ "security-framework",
+ "security-framework-sys",
+ "tempfile",
+]
+
 [[package]]
 name = "ntapi"
 version = "0.3.6"
@@ -331,6 +391,39 @@ version = "1.8.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56"
 
+[[package]]
+name = "openssl"
+version = "0.10.38"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c7ae222234c30df141154f159066c5093ff73b63204dcda7121eb082fc56a95"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "foreign-types",
+ "libc",
+ "once_cell",
+ "openssl-sys",
+]
+
+[[package]]
+name = "openssl-probe"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a"
+
+[[package]]
+name = "openssl-sys"
+version = "0.9.72"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7e46109c383602735fa0a2e48dd2b7c892b048e1bf69e5c3b1d804b7d9c203cb"
+dependencies = [
+ "autocfg",
+ "cc",
+ "libc",
+ "pkg-config",
+ "vcpkg",
+]
+
 [[package]]
 name = "os_str_bytes"
 version = "4.2.0"
@@ -391,6 +484,12 @@ version = "0.2.7"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "8d31d11c69a6b52a174b42bdc0c30e5e11670f90788b2c471c31c1d17d449443"
 
+[[package]]
+name = "pkg-config"
+version = "0.3.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe"
+
 [[package]]
 name = "ppv-lite86"
 version = "0.2.15"
@@ -484,6 +583,7 @@ name = "rathole"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-trait",
  "backoff",
  "bincode",
  "bytes",
@@ -496,6 +596,7 @@ dependencies = [
  "serde",
  "socket2",
  "tokio",
+ "tokio-native-tls",
  "toml",
  "tracing",
  "tracing-subscriber",
@@ -534,6 +635,15 @@ version = "0.6.25"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
 
+[[package]]
+name = "remove_dir_all"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3acd125665422973a33ac9d3dd2df85edad0f4ae9b00dafb1a05e43a9f5ef8e7"
+dependencies = [
+ "winapi",
+]
+
 [[package]]
 name = "ring"
 version = "0.16.20"
@@ -555,12 +665,45 @@ version = "1.0.8"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "b30e4c09749c107e83dd61baf9604198efc4542863c88af39dafcaca89c7c9f9"
 
+[[package]]
+name = "schannel"
+version = "0.1.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75"
+dependencies = [
+ "lazy_static",
+ "winapi",
+]
+
 [[package]]
 name = "scopeguard"
 version = "1.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
 
+[[package]]
+name = "security-framework"
+version = "2.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "525bc1abfda2e1998d152c45cf13e696f76d0a4972310b22fac1658b05df7c87"
+dependencies = [
+ "bitflags",
+ "core-foundation",
+ "core-foundation-sys",
+ "libc",
+ "security-framework-sys",
+]
+
+[[package]]
+name = "security-framework-sys"
+version = "2.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a9dd14d83160b528b7bfd66439110573efcfbe281b17fc2ca9f39f550d619c7e"
+dependencies = [
+ "core-foundation-sys",
+ "libc",
+]
+
 [[package]]
 name = "serde"
 version = "1.0.130"
@@ -649,6 +792,20 @@ dependencies = [
  "unicode-xid",
 ]
 
+[[package]]
+name = "tempfile"
+version = "3.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dac1c663cfc93810f88aed9b8941d48cabf856a1b111c29a40439018d870eb22"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "rand",
+ "redox_syscall",
+ "remove_dir_all",
+ "winapi",
+]
+
 [[package]]
 name = "termcolor"
 version = "1.1.2"
@@ -707,6 +864,16 @@ dependencies = [
  "syn",
 ]
 
+[[package]]
+name = "tokio-native-tls"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f7d995660bd2b7f8c1568414c1126076c13fbb725c40112dc0120b78eb9b717b"
+dependencies = [
+ "native-tls",
+ "tokio",
+]
+
 [[package]]
 name = "toml"
 version = "0.5.8"
@@ -824,6 +991,12 @@ version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
 
+[[package]]
+name = "vcpkg"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
+
 [[package]]
 name = "version_check"
 version = "0.9.3"

+ 2 - 0
Cargo.toml

@@ -26,3 +26,5 @@ tracing = "0.1"
 tracing-subscriber = "0.2"
 socket2 = "0.4"
 fdlimit = "0.2.1"
+tokio-native-tls = "0.3.0"
+async-trait = "0.1.52"

+ 12 - 1
README.md

@@ -61,6 +61,12 @@ Here is the full configuration specification:
 remote_addr = "example.com:2333" # Necessary. The address of the server
 default_token = "default_token_if_not_specify" # Optional. The default token of services, if they don't define their own ones
 
+[client.transport]
+type = "tcp" # Necessary if multiple transport blocks present. Possibile values: ["tcp", "tls"]. Default: "tcp"
+[client.transport.tls] # Necessary if `type` is "tls"
+trusted_root = "ca.pem" # Necessary. The certificate of CA that signed the server's certificate
+hostname = "example.com" # Optional. The hostname that the client uses to validate the certificate. If not set, fallback to `client.remote_addr`
+
 [client.services.service1] # A service that needs forwarding. The name `service1` can change arbitrarily, as long as identical to the name in the server's configuration
 token = "whatever" # Necessary if `client.default_token` not set
 local_addr = "127.0.0.1:1081" # Necessary. The address of the service that needs to be forwarded
@@ -72,6 +78,12 @@ local_addr = "127.0.0.1:1082"
 bind_addr = "0.0.0.0:2333" # Necessary. The address that the server listens for clients. Generally only the port needs to be change. 
 default_token = "default_token_if_not_specify" # Optional
 
+[server.transport]
+type = "tcp" # Same as `[client.transport]`
+[server.transport.tls]
+pkcs12 = "identify.pfx" # Necessary. pkcs12 file of server's certificate and private key
+pkcs12_password = "password" # Necessary. Password of the pkcs12 file
+
 [server.services.service1] # The service name must be identical to the client side
 token = "whatever" # Necesary if `server.default_token` not set
 bind_addr = "0.0.0.0:8081" # Necessary. The address of the service is exposed at. Generally only the port needs to be change. 
@@ -95,6 +107,5 @@ See also [Benchmark](./doc/benchmark.md).
 `rathole` is in active development. A load of features is on the way:
 
 - [ ] UDP support
-- [ ] TLS transport
 - [ ] Hot reloading
 - [ ] HTTP APIs for configuration

+ 35 - 0
example/tls/ca-cert.pem

@@ -0,0 +1,35 @@
+-----BEGIN CERTIFICATE-----
+MIIGKzCCBBOgAwIBAgIUaFOLSj0B/GBULJgIHBXIFCz6Fj0wDQYJKoZIhvcNAQEL
+BQAwgaQxCzAJBgNVBAYTAkZSMRIwEAYDVQQIDAlPY2NpdGFuaWUxETAPBgNVBAcM
+CFRvdWxvdXNlMRQwEgYDVQQKDAtUZWNoIFNjaG9vbDESMBAGA1UECwwJRWR1Y2F0
+aW9uMRowGAYDVQQDDBEqLnRlY2hzY2hvb2wuZ3VydTEoMCYGCSqGSIb3DQEJARYZ
+dGVjaHNjaG9vbC5ndXJ1QGdtYWlsLmNvbTAeFw0yMTEyMTYxMDExMjhaFw0yMjEy
+MTYxMDExMjhaMIGkMQswCQYDVQQGEwJGUjESMBAGA1UECAwJT2NjaXRhbmllMREw
+DwYDVQQHDAhUb3Vsb3VzZTEUMBIGA1UECgwLVGVjaCBTY2hvb2wxEjAQBgNVBAsM
+CUVkdWNhdGlvbjEaMBgGA1UEAwwRKi50ZWNoc2Nob29sLmd1cnUxKDAmBgkqhkiG
+9w0BCQEWGXRlY2hzY2hvb2wuZ3VydUBnbWFpbC5jb20wggIiMA0GCSqGSIb3DQEB
+AQUAA4ICDwAwggIKAoICAQC0T91ZjFcuptmtsZjHbfg72G50lPaiAZtvWBRv7Zms
+qgJVKyATmQ3eak8q9R69jffv3MBl2vuY0rSUyG0MYTH7vpk5XrllFjZruhRb9Roo
+wdDNAbmuB6ogP0wAgTVRSs3RYxyADOyXr8jnngPYLw5HWsBCJgmW+ARLZD0c2wvu
+GJ8Boxxy/tY4WQ18UIA81cjTorazKKRokRjFrx6AVlglJjyZt3lXHdq2nM1fsAjl
+Hhr5g0JCgV40bfwXHRc0E7lMe4U1LNYVy6Mmfv0d3xIsiwnGuhggAYMSsebneVZ7
+PoPIeqOdnm3sgh/18EdbhxOd+RV/AwZfWG35dsm+xrsqeAg4qSd2hvdlj8kba4dT
+Uvj6gpRirQXL0ZdPEsh9iG3xvxuwp/S7NKiuXm4TUiUWCCvtHyAUR/IpDkgswjNF
+cxPDwPGq1xWIqjjz9A8ucGqcWmCDejdsnazG7m62ZrnyJcRehf6TwtsrEjIFTigB
+a6ynWEGXqfyX4D+MtC9Y7++OksjCTR9m17SFAQh01NSBix1YaY/u+com1MGjqQTp
+bZGDmIdEtuV66+/pCl65Qt5B9M8JqgxSozvNqELl6RrW7iVlkL+Yi/fGd3Ms02Nm
+vmjSGwbmPVYqmwCysXdQbVbKMABilbzIUbXd3mjlJvrWiXdBEazflMk7g/Vh7Ont
+5wIDAQABo1MwUTAdBgNVHQ4EFgQUuyCI+Jcp9EtA+c0PubGLiX7JXbcwHwYDVR0j
+BBgwFoAUuyCI+Jcp9EtA+c0PubGLiX7JXbcwDwYDVR0TAQH/BAUwAwEB/zANBgkq
+hkiG9w0BAQsFAAOCAgEAJwpRh5UVWKSgahONtDgbAWX2rd/UQKUMhNkLhKVNpMoF
+xUvGcIDYrwlv8kvSk+h7sPooc/MhzZjGJByQug0kDSeh8Zy2z6L64DON4TiUYLNc
+ackYLjZGA/NKgzyIExljMEbuPzrb0sco7Xw1mbDPTiXWsw7WHvGID13daxrmOwjx
+s8M6+4rmYTv5bqO87vecGbOjqH33gdWIkVYPYapeZQYT6ExPYOfQnkNa1nFjI8Jy
+all0do2lVdUS+csLi/ONm1aI/1acgPMrTxy5nEqqinLv9KPNgXqDbuSUKIlu6+mP
+dj8Fl5kmHT0Us8XLTkQAhz0pBiKCI6L/njlDDfB2VmU+Mh1Oyer+wU4JTvJuVtZe
+9K1sccbxHT16stPF/jkqYYc6TlLCjRRvhD+NfZKpugaEEmP0uRx2vBH0N/z+u6fJ
+pA7NtQP0/fZb2VULQYvycvVIEDo98yQffkviHlUWGKFALNvVOYrtNkhuq5Jf1JrV
+s0SwRKF9oS27/CDsaYr+EBzBiu++sZUghAL1+tqPy8WkJIaCbxjqDPB6WW6zjv30
+A8wSe7dfj3zlmYsgVaYJh2GzypwVFvldn11m7TkKdmGHD+3eESGLucZN9r9nmaG3
+ARy8Vnhqtqt12mvW5t9cJvBk0HZG13IOpS1ErBkm5vDll1y30QrdPtFBeY9Ph0Y=
+-----END CERTIFICATE-----

+ 12 - 0
example/tls/client.toml

@@ -0,0 +1,12 @@
+[client]
+remote_addr = "localhost:2333"
+default_token = "123"
+
+[client.transport]
+type = "tls"
+[client.transport.tls]
+trusted_root = "example/tls/ca-cert.pem"
+hostname = "0.0.0.0"
+
+[client.services.foo1]
+local_addr = "127.0.0.1:80"

二進制
example/tls/identity.pfx


+ 13 - 0
example/tls/server.toml

@@ -0,0 +1,13 @@
+[server]
+bind_addr = "0.0.0.0:2333"
+default_token = "123"
+
+[server.transport]
+type = "tls"
+[server.transport.tls]
+pkcs12 = "example/tls/identity.pfx"
+pkcs12_password = "1234"
+
+[server.services.foo1]
+bind_addr = "0.0.0.0:5202"
+

+ 62 - 32
src/client.rs

@@ -1,50 +1,68 @@
 use std::collections::HashMap;
 use std::sync::Arc;
 
-use crate::config::{ClientConfig, ClientServiceConfig, Config};
+use crate::config::{ClientConfig, ClientServiceConfig, Config, TransportType};
 use crate::protocol::{
-    self, read_hello, DataChannelCmd,
+    self, Ack, Auth, ControlChannelCmd, DataChannelCmd,
     Hello::{self, *},
     CURRENT_PROTO_VRESION, HASH_WIDTH_IN_BYTES,
 };
-use crate::protocol::{read_data_cmd, Ack, Auth, ControlChannelCmd};
+use crate::protocol::{read_ack, read_control_cmd, read_data_cmd, read_hello};
+use crate::transport::{TcpTransport, TlsTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
 use backoff::ExponentialBackoff;
-use tokio::io;
+
+use tokio::io::{copy_bidirectional, AsyncWriteExt};
+use tokio::net::TcpStream;
 use tokio::sync::oneshot;
 use tokio::time::{self, Duration};
-use tokio::{self, io::AsyncWriteExt, net::TcpStream};
 use tracing::{debug, error, info, instrument, Instrument, Span};
 
 pub async fn run_client(config: &Config) -> Result<()> {
-    let mut client = Client::from(config)?;
-    client.run().await
+    let config = match &config.client {
+        Some(v) => v,
+        None => {
+            return Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
+        }
+    };
+
+    match config.transport.transport_type {
+        TransportType::Tcp => {
+            let mut client = Client::<TcpTransport>::from(&config).await?;
+            client.run().await
+        }
+        TransportType::Tls => {
+            let mut client = Client::<TlsTransport>::from(&config).await?;
+            client.run().await
+        }
+    }
 }
 
 type ServiceDigest = protocol::Digest;
 type Nonce = protocol::Digest;
 
-struct Client<'a> {
+struct Client<'a, T: Transport> {
     config: &'a ClientConfig,
     service_handles: HashMap<String, ControlChannelHandle>,
+    transport: Arc<T>,
 }
 
-impl<'a> Client<'a> {
-    fn from(config: &'a Config) -> Result<Client> {
-        if let Some(config) = &config.client {
-            Ok(Client {
-                config,
-                service_handles: HashMap::new(),
-            })
-        } else {
-            Err(anyhow!("Try to run as a client, but the configuration is missing. Please add the `[client]` block"))
-        }
+impl<'a, T: 'static + Transport> Client<'a, T> {
+    async fn from(config: &'a ClientConfig) -> Result<Client<'a, T>> {
+        Ok(Client {
+            config,
+            service_handles: HashMap::new(),
+            transport: Arc::new(*T::new(&config.transport).await?),
+        })
     }
 
     async fn run(&mut self) -> Result<()> {
         for (name, config) in &self.config.services {
-            let handle =
-                ControlChannelHandle::new((*config).clone(), self.config.remote_addr.clone());
+            let handle = ControlChannelHandle::new(
+                (*config).clone(),
+                self.config.remote_addr.clone(),
+                self.transport.clone(),
+            );
             self.service_handles.insert(name.clone(), handle);
         }
 
@@ -71,13 +89,14 @@ impl<'a> Client<'a> {
     }
 }
 
-struct RunDataChannelArgs {
+struct RunDataChannelArgs<T: Transport> {
     session_key: Nonce,
     remote_addr: String,
     local_addr: String,
+    connector: Arc<T>,
 }
 
-async fn run_data_channel(args: Arc<RunDataChannelArgs>) -> Result<()> {
+async fn run_data_channel<T: Transport>(args: Arc<RunDataChannelArgs<T>>) -> Result<()> {
     // Retry at least every 100ms, at most for 10 seconds
     let backoff = ExponentialBackoff {
         max_interval: Duration::from_millis(100),
@@ -86,8 +105,10 @@ async fn run_data_channel(args: Arc<RunDataChannelArgs>) -> Result<()> {
     };
 
     // Connect to remote_addr
-    let mut conn = backoff::future::retry(backoff, || async {
-        Ok(TcpStream::connect(&args.remote_addr)
+    let mut conn: T::Stream = backoff::future::retry(backoff, || async {
+        Ok(args
+            .connector
+            .connect(&args.remote_addr)
             .await
             .with_context(|| "Failed to connect to remote_addr")?)
     })
@@ -104,27 +125,30 @@ async fn run_data_channel(args: Arc<RunDataChannelArgs>) -> Result<()> {
             let mut local = TcpStream::connect(&args.local_addr)
                 .await
                 .with_context(|| "Failed to conenct to local_addr")?;
-            let _ = io::copy_bidirectional(&mut conn, &mut local).await;
+            let _ = copy_bidirectional(&mut conn, &mut local).await;
         }
     }
     Ok(())
 }
 
-struct ControlChannel {
+struct ControlChannel<T: Transport> {
     digest: ServiceDigest,
     service: ClientServiceConfig,
     shutdown_rx: oneshot::Receiver<u8>,
     remote_addr: String,
+    transport: Arc<T>,
 }
 
 struct ControlChannelHandle {
     shutdown_tx: oneshot::Sender<u8>,
 }
 
-impl ControlChannel {
+impl<T: 'static + Transport> ControlChannel<T> {
     #[instrument(skip(self), fields(service=%self.service.name))]
     async fn run(&mut self) -> Result<()> {
-        let mut conn = TcpStream::connect(&self.remote_addr)
+        let mut conn = self
+            .transport
+            .connect(&self.remote_addr)
             .await
             .with_context(|| format!("Failed to connect to the server: {}", &self.remote_addr))?;
 
@@ -134,7 +158,7 @@ impl ControlChannel {
         conn.write_all(&bincode::serialize(&hello_send).unwrap())
             .await?;
 
-        // Read hello
+        // Read hello))
         let nonce = match read_hello(&mut conn)
             .await
             .with_context(|| "Failed to read hello from the server")?
@@ -154,7 +178,7 @@ impl ControlChannel {
         conn.write_all(&bincode::serialize(&auth).unwrap()).await?;
 
         // Read ack
-        match protocol::read_ack(&mut conn).await? {
+        match read_ack(&mut conn).await? {
             Ack::Ok => {}
             v => {
                 return Err(anyhow!("{}", v))
@@ -171,11 +195,12 @@ impl ControlChannel {
             session_key,
             remote_addr,
             local_addr,
+            connector: self.transport.clone(),
         });
 
         loop {
             tokio::select! {
-                val = protocol::read_control_cmd(&mut conn) => {
+                val = read_control_cmd(&mut conn) => {
                     let val = val?;
                     debug!( "Received {:?}", val);
                     match val {
@@ -202,7 +227,11 @@ impl ControlChannel {
 
 impl ControlChannelHandle {
     #[instrument(skip_all, fields(service = %service.name))]
-    fn new(service: ClientServiceConfig, remote_addr: String) -> ControlChannelHandle {
+    fn new<T: 'static + Transport>(
+        service: ClientServiceConfig,
+        remote_addr: String,
+        transport: Arc<T>,
+    ) -> ControlChannelHandle {
         let digest = protocol::digest(service.name.as_bytes());
         let (shutdown_tx, shutdown_rx) = oneshot::channel();
         let mut s = ControlChannel {
@@ -210,6 +239,7 @@ impl ControlChannelHandle {
             service,
             shutdown_rx,
             remote_addr,
+            transport,
         };
 
         tokio::spawn(

+ 95 - 28
src/config.rs

@@ -6,15 +6,17 @@ use tokio::fs;
 use toml;
 
 #[derive(Debug, Serialize, Deserialize, Copy, Clone)]
-pub enum Encryption {
-    #[serde(rename = "none")]
-    None,
-    #[serde(rename = "aes")]
-    Aes,
+pub enum TransportType {
+    #[serde(rename = "tcp")]
+    Tcp,
+    #[serde(rename = "tls")]
+    Tls,
 }
 
-fn default_encryption() -> Encryption {
-    Encryption::None
+impl Default for TransportType {
+    fn default() -> TransportType {
+        TransportType::Tcp
+    }
 }
 
 #[derive(Debug, Serialize, Deserialize, Clone)]
@@ -23,8 +25,6 @@ pub struct ClientServiceConfig {
     pub name: String,
     pub local_addr: String,
     pub token: Option<String>,
-    #[serde(default = "default_encryption")]
-    pub encryption: Encryption,
 }
 
 #[derive(Debug, Serialize, Deserialize, Clone)]
@@ -33,8 +33,25 @@ pub struct ServerServiceConfig {
     pub name: String,
     pub bind_addr: String,
     pub token: Option<String>,
-    #[serde(default = "default_encryption")]
-    pub encryption: Encryption,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct TlsConfig {
+    pub hostname: Option<String>,
+    pub trusted_root: Option<String>,
+    pub pkcs12: Option<String>,
+    pub pkcs12_password: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize, Default)]
+pub struct TransportConfig {
+    #[serde(rename = "type")]
+    pub transport_type: TransportType,
+    pub tls: Option<TlsConfig>,
+}
+
+fn default_transport() -> TransportConfig {
+    Default::default()
 }
 
 #[derive(Debug, Serialize, Deserialize, Default)]
@@ -42,6 +59,8 @@ pub struct ClientConfig {
     pub remote_addr: String,
     pub default_token: Option<String>,
     pub services: HashMap<String, ClientServiceConfig>,
+    #[serde(default = "default_transport")]
+    pub transport: TransportConfig,
 }
 
 #[derive(Debug, Serialize, Deserialize, Default)]
@@ -49,6 +68,8 @@ pub struct ServerConfig {
     pub bind_addr: String,
     pub default_token: Option<String>,
     pub services: HashMap<String, ServerServiceConfig>,
+    #[serde(default = "default_transport")]
+    pub transport: TransportConfig,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -62,32 +83,78 @@ impl Config {
     fn from_str(s: &str) -> Result<Config> {
         let mut config: Config =
             toml::from_str(&s).with_context(|| "Failed to parse the config")?;
+
         if let Some(server) = config.server.as_mut() {
-            for (name, s) in &mut server.services {
-                s.name = name.clone();
+            Config::validate_server_config(server)?;
+        }
+
+        if let Some(client) = config.client.as_mut() {
+            Config::validate_client_config(client)?;
+        }
+
+        if config.server.is_none() && config.client.is_none() {
+            Err(anyhow!("Neither of `[server]` or `[client]` is defined"))
+        } else {
+            Ok(config)
+        }
+    }
+
+    fn validate_server_config(server: &mut ServerConfig) -> Result<()> {
+        // Validate services
+        for (name, s) in &mut server.services {
+            s.name = name.clone();
+            if s.token.is_none() {
+                s.token = server.default_token.clone();
                 if s.token.is_none() {
-                    s.token = server.default_token.clone();
-                    if s.token.is_none() {
-                        bail!("The token of service {} is not set", name);
-                    }
+                    bail!("The token of service {} is not set", name);
                 }
             }
         }
-        if let Some(client) = config.client.as_mut() {
-            for (name, s) in &mut client.services {
-                s.name = name.clone();
+
+        Config::validate_transport_config(&server.transport, true)?;
+
+        Ok(())
+    }
+
+    fn validate_client_config(client: &mut ClientConfig) -> Result<()> {
+        // Validate services
+        for (name, s) in &mut client.services {
+            s.name = name.clone();
+            if s.token.is_none() {
+                s.token = client.default_token.clone();
                 if s.token.is_none() {
-                    s.token = client.default_token.clone();
-                    if s.token.is_none() {
-                        bail!("The token of service {} is not set", name);
-                    }
+                    bail!("The token of service {} is not set", name);
                 }
             }
         }
-        if config.server.is_none() && config.client.is_none() {
-            Err(anyhow!("Neither of `[server]` or `[client]` is defined"))
-        } else {
-            Ok(config)
+
+        Config::validate_transport_config(&client.transport, false)?;
+
+        Ok(())
+    }
+
+    fn validate_transport_config(config: &TransportConfig, is_server: bool) -> Result<()> {
+        match config.transport_type {
+            TransportType::Tcp => Ok(()),
+            TransportType::Tls => {
+                let tls_config = config
+                    .tls
+                    .as_ref()
+                    .ok_or(anyhow!("Missing TLS configuration"))?;
+                if is_server {
+                    tls_config
+                        .pkcs12
+                        .as_ref()
+                        .and(tls_config.pkcs12_password.as_ref())
+                        .ok_or(anyhow!("Missing `pkcs12` or `pkcs12_password`"))?;
+                } else {
+                    tls_config
+                        .trusted_root
+                        .as_ref()
+                        .ok_or(anyhow!("Missing `trusted_root`"))?;
+                }
+                Ok(())
+            }
         }
     }
 

+ 1 - 0
src/lib.rs

@@ -5,6 +5,7 @@ mod helper;
 mod multi_map;
 mod protocol;
 mod server;
+mod transport;
 
 pub use cli::Cli;
 pub use config::Config;

+ 11 - 10
src/protocol.rs

@@ -1,12 +1,9 @@
 pub const HASH_WIDTH_IN_BYTES: usize = 32;
-use anyhow::{Context, Result};
-use bincode;
 
+use anyhow::{Context, Result};
 use lazy_static::lazy_static;
-
 use serde::{Deserialize, Serialize};
-use tokio::io::AsyncReadExt;
-use tokio::net::TcpStream;
+use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite};
 
 type ProtocolVersion = u8;
 const PROTO_V0: u8 = 0u8;
@@ -95,7 +92,7 @@ lazy_static! {
     static ref PACKET_LEN: PacketLength = PacketLength::new();
 }
 
-pub async fn read_hello(conn: &mut TcpStream) -> Result<Hello> {
+pub async fn read_hello<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Result<Hello> {
     let mut buf = vec![0u8; PACKET_LEN.hello];
     conn.read_exact(&mut buf)
         .await
@@ -104,7 +101,7 @@ pub async fn read_hello(conn: &mut TcpStream) -> Result<Hello> {
     Ok(hello)
 }
 
-pub async fn read_auth(conn: &mut TcpStream) -> Result<Auth> {
+pub async fn read_auth<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Result<Auth> {
     let mut buf = vec![0u8; PACKET_LEN.auth];
     conn.read_exact(&mut buf)
         .await
@@ -112,7 +109,7 @@ pub async fn read_auth(conn: &mut TcpStream) -> Result<Auth> {
     bincode::deserialize(&buf).with_context(|| "Failed to deserialize auth")
 }
 
-pub async fn read_ack(conn: &mut TcpStream) -> Result<Ack> {
+pub async fn read_ack<T: AsyncRead + AsyncWrite + Unpin>(conn: &mut T) -> Result<Ack> {
     let mut bytes = vec![0u8; PACKET_LEN.ack];
     conn.read_exact(&mut bytes)
         .await
@@ -120,7 +117,9 @@ pub async fn read_ack(conn: &mut TcpStream) -> Result<Ack> {
     bincode::deserialize(&bytes).with_context(|| "Failed to deserialize ack")
 }
 
-pub async fn read_control_cmd(conn: &mut TcpStream) -> Result<ControlChannelCmd> {
+pub async fn read_control_cmd<T: AsyncRead + AsyncWrite + Unpin>(
+    conn: &mut T,
+) -> Result<ControlChannelCmd> {
     let mut bytes = vec![0u8; PACKET_LEN.c_cmd];
     conn.read_exact(&mut bytes)
         .await
@@ -128,7 +127,9 @@ pub async fn read_control_cmd(conn: &mut TcpStream) -> Result<ControlChannelCmd>
     bincode::deserialize(&bytes).with_context(|| "Failed to deserialize control cmd")
 }
 
-pub async fn read_data_cmd(conn: &mut TcpStream) -> Result<DataChannelCmd> {
+pub async fn read_data_cmd<T: AsyncRead + AsyncWrite + Unpin>(
+    conn: &mut T,
+) -> Result<DataChannelCmd> {
     let mut bytes = vec![0u8; PACKET_LEN.d_cmd];
     conn.read_exact(&mut bytes)
         .await

+ 90 - 85
src/server.rs

@@ -1,29 +1,25 @@
-use std::collections::HashMap;
-use std::net::SocketAddr;
-use std::sync::Arc;
-use std::time::Duration;
-
-use crate::config::{Config, ServerConfig, ServerServiceConfig};
-use crate::helper::set_tcp_keepalive;
+use crate::config::{Config, ServerConfig, ServerServiceConfig, TransportType};
 use crate::multi_map::MultiMap;
 use crate::protocol::{
-    self, read_hello, Hello, Hello::ControlChannelHello, Hello::DataChannelHello,
+    self, Ack, ControlChannelCmd, DataChannelCmd, Hello, Hello::ControlChannelHello,
+    Hello::DataChannelHello, HASH_WIDTH_IN_BYTES,
 };
-use crate::protocol::{read_auth, Ack, ControlChannelCmd, DataChannelCmd, HASH_WIDTH_IN_BYTES};
+use crate::protocol::{read_auth, read_hello};
+use crate::transport::{TcpTransport, TlsTransport, Transport};
 use anyhow::{anyhow, bail, Context, Result};
+use backoff::{backoff::Backoff, ExponentialBackoff};
 use rand::RngCore;
-use tokio::io::{self, AsyncWriteExt};
+use std::collections::HashMap;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::io::{self, copy_bidirectional, AsyncWriteExt};
+use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::mpsc;
 use tokio::sync::{oneshot, RwLock};
 use tokio::time;
-use tokio::{
-    self,
-    net::{self, TcpListener, TcpStream},
-};
 use tracing::{debug, error, info, info_span, warn, Instrument};
 
-use backoff::{backoff::Backoff, ExponentialBackoff};
-
 type ServiceDigest = protocol::Digest;
 type Nonce = protocol::Digest;
 
@@ -31,43 +27,57 @@ const POOL_SIZE: usize = 64;
 const CHAN_SIZE: usize = 2048;
 
 pub async fn run_server(config: &Config) -> Result<()> {
-    let mut server = Server::from(config)?;
-
-    server.run().await
+    let config = match &config.server {
+            Some(config) => config,
+            None => {
+                return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
+            }
+        };
+    match config.transport.transport_type {
+        TransportType::Tcp => {
+            let mut server = Server::<TcpTransport>::from(config).await?;
+            server.run().await?;
+        }
+        TransportType::Tls => {
+            let mut server = Server::<TlsTransport>::from(config).await?;
+            server.run().await?;
+        }
+    }
+    Ok(())
 }
 
-type ControlChannelMap = MultiMap<ServiceDigest, Nonce, ControlChannelHandle>;
-struct Server<'a> {
+type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;
+struct Server<'a, T: Transport> {
     config: &'a ServerConfig,
     services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
-    control_channels: Arc<RwLock<ControlChannelMap>>,
+    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
+    transport: Arc<T>,
 }
 
-impl<'a> Server<'a> {
-    pub fn from(config: &'a Config) -> Result<Server> {
-        match &config.server {
-            Some(config) => Ok(Server {
-                config,
-                services: Arc::new(RwLock::new(Server::generate_service_hashmap(config))),
-                control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
-            }),
-            None =>
-            Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
-        }
+fn generate_service_hashmap(
+    server_config: &ServerConfig,
+) -> HashMap<ServiceDigest, ServerServiceConfig> {
+    let mut ret = HashMap::new();
+    for u in &server_config.services {
+        ret.insert(protocol::digest(u.0.as_bytes()), (*u.1).clone());
     }
+    ret
+}
 
-    fn generate_service_hashmap(
-        server_config: &ServerConfig,
-    ) -> HashMap<ServiceDigest, ServerServiceConfig> {
-        let mut ret = HashMap::new();
-        for u in &server_config.services {
-            ret.insert(protocol::digest(u.0.as_bytes()), (*u.1).clone());
-        }
-        ret
+impl<'a, T: 'static + Transport> Server<'a, T> {
+    pub async fn from(config: &'a ServerConfig) -> Result<Server<'a, T>> {
+        Ok(Server {
+            config,
+            services: Arc::new(RwLock::new(generate_service_hashmap(config))),
+            control_channels: Arc::new(RwLock::new(ControlChannelMap::new())),
+            transport: Arc::new(*(T::new(&config.transport).await?)),
+        })
     }
 
     pub async fn run(&mut self) -> Result<()> {
-        let l = net::TcpListener::bind(&self.config.bind_addr)
+        let l = self
+            .transport
+            .bind(&self.config.bind_addr)
             .await
             .with_context(|| "Failed to listen at `server.bind_addr`")?;
         info!("Listening at {}", self.config.bind_addr);
@@ -82,22 +92,25 @@ impl<'a> Server<'a> {
         // Listen for incoming control or data channels
         loop {
             tokio::select! {
-                ret = l.accept() => {
+                ret = self.transport.accept(&l) => {
                     match ret {
                         Err(err) => {
-                            // Possibly a EMFILE. So sleep for a while and retry
-                            if let Some(d) = backoff.next_backoff() {
-                                error!("Failed to accept: {}. Retry in {:?}...", err, d);
-                                time::sleep(d).await;
-                            } else {
-                                // This branch will never be executed according to the current retry policy
-                                error!("Too many retries. Aborting...");
-                                break;
+                            if let Some(err) = err.downcast_ref::<io::Error>() {
+                                // Possibly a EMFILE. So sleep for a while and retry
+                                if let Some(d) = backoff.next_backoff() {
+                                    error!("Failed to accept: {}. Retry in {:?}...", err, d);
+                                    time::sleep(d).await;
+                                } else {
+                                    // This branch will never be executed according to the current retry policy
+                                    error!("Too many retries. Aborting...");
+                                    break;
+                                }
                             }
                         }
                         Ok((conn, addr)) => {
                             backoff.reset();
                             debug!("Incomming connection from {}", addr);
+
                             let services = self.services.clone();
                             let control_channels = self.control_channels.clone();
                             tokio::spawn(async move {
@@ -119,11 +132,11 @@ impl<'a> Server<'a> {
     }
 }
 
-async fn handle_connection(
-    mut conn: TcpStream,
+async fn handle_connection<T: 'static + Transport>(
+    mut conn: T::Stream,
     addr: SocketAddr,
     services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
-    control_channels: Arc<RwLock<ControlChannelMap>>,
+    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
 ) -> Result<()> {
     // Read hello
     let hello = read_hello(&mut conn).await?;
@@ -139,11 +152,11 @@ async fn handle_connection(
     Ok(())
 }
 
-async fn do_control_channel_handshake(
-    mut conn: TcpStream,
+async fn do_control_channel_handshake<T: 'static + Transport>(
+    mut conn: T::Stream,
     addr: SocketAddr,
     services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
-    control_channels: Arc<RwLock<ControlChannelMap>>,
+    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
     service_digest: ServiceDigest,
 ) -> Result<()> {
     info!("New control channel incomming from {}", addr);
@@ -219,19 +232,15 @@ async fn do_control_channel_handshake(
     Ok(())
 }
 
-async fn do_data_channel_handshake(
-    conn: TcpStream,
-    control_channels: Arc<RwLock<ControlChannelMap>>,
+async fn do_data_channel_handshake<T: Transport>(
+    conn: T::Stream,
+    control_channels: Arc<RwLock<ControlChannelMap<T>>>,
     nonce: Nonce,
 ) -> Result<()> {
     // Validate
     let control_channels_guard = control_channels.read().await;
     match control_channels_guard.get2(&nonce) {
         Some(c_ch) => {
-            if let Err(e) = set_tcp_keepalive(&conn) {
-                error!("The connection may be unstable! {:?}", e);
-            }
-
             // Send the data channel to the corresponding control channel
             c_ch.conn_pool.data_ch_tx.send(conn).await?;
         }
@@ -242,24 +251,24 @@ async fn do_data_channel_handshake(
     Ok(())
 }
 
-struct ControlChannel {
-    conn: TcpStream,
+struct ControlChannel<T: Transport> {
+    conn: T::Stream,
     service: ServerServiceConfig,
     shutdown_rx: oneshot::Receiver<bool>,
     visitor_tx: mpsc::Sender<TcpStream>,
 }
 
-struct ControlChannelHandle {
+struct ControlChannelHandle<T: Transport> {
     shutdown_tx: oneshot::Sender<bool>,
-    conn_pool: ConnectionPoolHandle,
+    conn_pool: ConnectionPoolHandle<T>,
 }
 
-impl ControlChannelHandle {
-    fn new(conn: TcpStream, service: ServerServiceConfig) -> ControlChannelHandle {
+impl<T: 'static + Transport> ControlChannelHandle<T> {
+    fn new(conn: T::Stream, service: ServerServiceConfig) -> ControlChannelHandle<T> {
         let (shutdown_tx, shutdown_rx) = oneshot::channel::<bool>();
         let name = service.name.clone();
         let conn_pool = ConnectionPoolHandle::new();
-        let actor = ControlChannel {
+        let actor: ControlChannel<T> = ControlChannel {
             conn,
             shutdown_rx,
             service,
@@ -279,13 +288,9 @@ impl ControlChannelHandle {
     }
 }
 
-impl ControlChannel {
+impl<T: Transport> ControlChannel<T> {
     #[tracing::instrument(skip(self), fields(service = %self.service.name))]
     async fn run(mut self) -> Result<()> {
-        if let Err(e) = set_tcp_keepalive(&self.conn) {
-            error!("The connection may be unstable! {:?}", e);
-        }
-
         let l = match TcpListener::bind(&self.service.bind_addr).await {
             Ok(v) => v,
             Err(e) => {
@@ -360,21 +365,21 @@ impl ControlChannel {
 }
 
 #[derive(Debug)]
-struct ConnectionPool {
+struct ConnectionPool<T: Transport> {
     visitor_rx: mpsc::Receiver<TcpStream>,
-    data_ch_rx: mpsc::Receiver<TcpStream>,
+    data_ch_rx: mpsc::Receiver<T::Stream>,
 }
 
-struct ConnectionPoolHandle {
+struct ConnectionPoolHandle<T: Transport> {
     visitor_tx: mpsc::Sender<TcpStream>,
-    data_ch_tx: mpsc::Sender<TcpStream>,
+    data_ch_tx: mpsc::Sender<T::Stream>,
 }
 
-impl ConnectionPoolHandle {
-    fn new() -> ConnectionPoolHandle {
+impl<T: 'static + Transport> ConnectionPoolHandle<T> {
+    fn new() -> ConnectionPoolHandle<T> {
         let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
         let (visitor_tx, visitor_rx) = mpsc::channel(CHAN_SIZE);
-        let conn_pool = ConnectionPool {
+        let conn_pool: ConnectionPool<T> = ConnectionPool {
             data_ch_rx,
             visitor_rx,
         };
@@ -388,7 +393,7 @@ impl ConnectionPoolHandle {
     }
 }
 
-impl ConnectionPool {
+impl<T: Transport> ConnectionPool<T> {
     #[tracing::instrument]
     async fn run(mut self) {
         loop {
@@ -397,7 +402,7 @@ impl ConnectionPool {
                     tokio::spawn(async move {
                         let cmd = bincode::serialize(&DataChannelCmd::StartForward).unwrap();
                         if ch.write_all(&cmd).await.is_ok() {
-                            let _ = io::copy_bidirectional(&mut ch, &mut visitor).await;
+                            let _ = copy_bidirectional(&mut ch, &mut visitor).await;
                         }
                     });
                 } else {

+ 25 - 0
src/transport/mod.rs

@@ -0,0 +1,25 @@
+use crate::config::TransportConfig;
+use anyhow::Result;
+use async_trait::async_trait;
+use std::fmt::Debug;
+use std::net::SocketAddr;
+use tokio::{
+    io::{AsyncRead, AsyncWrite},
+    net::ToSocketAddrs,
+};
+
+#[async_trait]
+pub trait Transport: Debug + Send + Sync {
+    type Acceptor: Send + Sync;
+    type Stream: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync + Debug;
+
+    async fn new(config: &TransportConfig) -> Result<Box<Self>>;
+    async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor>;
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)>;
+    async fn connect(&self, addr: &String) -> Result<Self::Stream>;
+}
+
+mod tcp;
+mod tls;
+pub use tcp::TcpTransport;
+pub use tls::TlsTransport;

+ 41 - 0
src/transport/tcp.rs

@@ -0,0 +1,41 @@
+use crate::{config::TransportConfig, helper::set_tcp_keepalive};
+
+use super::Transport;
+use anyhow::Result;
+use async_trait::async_trait;
+use std::net::SocketAddr;
+use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
+use tracing::error;
+
+#[derive(Debug)]
+pub struct TcpTransport {}
+
+#[async_trait]
+impl Transport for TcpTransport {
+    type Acceptor = TcpListener;
+    type Stream = TcpStream;
+
+    async fn new(_config: &TransportConfig) -> Result<Box<Self>> {
+        Ok(Box::new(TcpTransport {}))
+    }
+
+    async fn bind<T: ToSocketAddrs + Send + Sync>(&self, addr: T) -> Result<Self::Acceptor> {
+        Ok(TcpListener::bind(addr).await?)
+    }
+
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
+        let (s, addr) = a.accept().await?;
+        Ok((s, addr))
+    }
+
+    async fn connect(&self, addr: &String) -> Result<Self::Stream> {
+        let s = TcpStream::connect(addr).await?;
+        if let Err(e) = set_tcp_keepalive(&s) {
+            error!(
+                "Failed to set TCP keepalive. The connection maybe unstable: {:?}",
+                e
+            );
+        }
+        Ok(s)
+    }
+}

+ 96 - 0
src/transport/tls.rs

@@ -0,0 +1,96 @@
+use std::net::SocketAddr;
+
+use super::Transport;
+use crate::{
+    config::{TlsConfig, TransportConfig},
+    helper::set_tcp_keepalive,
+};
+use anyhow::{anyhow, Context, Result};
+use async_trait::async_trait;
+use tokio::{
+    fs,
+    net::{TcpListener, TcpStream, ToSocketAddrs},
+};
+use tokio_native_tls::{
+    native_tls::{self, Certificate, Identity},
+    TlsAcceptor, TlsConnector, TlsStream,
+};
+use tracing::error;
+
+#[derive(Debug)]
+pub struct TlsTransport {
+    config: TlsConfig,
+    connector: Option<TlsConnector>,
+}
+
+#[async_trait]
+impl Transport for TlsTransport {
+    type Acceptor = (TcpListener, TlsAcceptor);
+    type Stream = TlsStream<TcpStream>;
+
+    async fn new(config: &TransportConfig) -> Result<Box<Self>> {
+        let config = match &config.tls {
+            Some(v) => v,
+            None => {
+                return Err(anyhow!("Missing tls config"));
+            }
+        };
+
+        let connector = match config.trusted_root.as_ref() {
+            Some(path) => {
+                let s = fs::read_to_string(path).await?;
+                let cert = Certificate::from_pem(&s.as_bytes())?;
+                let connector = native_tls::TlsConnector::builder()
+                    .add_root_certificate(cert)
+                    .build()?;
+                Some(TlsConnector::from(connector))
+            }
+            None => None,
+        };
+
+        Ok(Box::new(TlsTransport {
+            config: config.clone(),
+            connector,
+        }))
+    }
+
+    async fn bind<A: ToSocketAddrs + Send + Sync>(&self, addr: A) -> Result<Self::Acceptor> {
+        let ident = Identity::from_pkcs12(
+            &fs::read(self.config.pkcs12.as_ref().unwrap()).await?,
+            self.config.pkcs12_password.as_ref().unwrap(),
+        )
+        .with_context(|| "Failed to create identitiy")?;
+        let l = TcpListener::bind(addr)
+            .await
+            .with_context(|| "Failed to create tcp listener")?;
+        let t = TlsAcceptor::from(native_tls::TlsAcceptor::new(ident).unwrap());
+        Ok((l, t))
+    }
+
+    async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
+        let (conn, addr) = a.0.accept().await?;
+        let conn = a.1.accept(conn).await?;
+
+        Ok((conn, addr))
+    }
+
+    async fn connect(&self, addr: &String) -> Result<Self::Stream> {
+        let conn = TcpStream::connect(&addr).await?;
+        if let Err(e) = set_tcp_keepalive(&conn) {
+            error!(
+                "Failed to set TCP keepalive. The connection maybe unstable: {:?}",
+                e
+            );
+        }
+        let connector = self.connector.as_ref().unwrap();
+        Ok(connector
+            .connect(
+                self.config
+                    .hostname
+                    .as_ref()
+                    .unwrap_or(&String::from(addr.split(':').next().unwrap())),
+                conn,
+            )
+            .await?)
+    }
+}