tls.rs 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. use std::net::SocketAddr;
  2. use super::Transport;
  3. use crate::config::{TlsConfig, TransportConfig};
  4. use crate::helper::set_tcp_keepalive;
  5. use anyhow::{anyhow, Context, Result};
  6. use async_trait::async_trait;
  7. use tokio::fs;
  8. use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
  9. use tokio_native_tls::native_tls::{self, Certificate, Identity};
  10. use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
  11. use tracing::error;
  12. #[derive(Debug)]
  13. pub struct TlsTransport {
  14. config: TlsConfig,
  15. connector: Option<TlsConnector>,
  16. }
  17. #[async_trait]
  18. impl Transport for TlsTransport {
  19. type Acceptor = (TcpListener, TlsAcceptor);
  20. type Stream = TlsStream<TcpStream>;
  21. async fn new(config: &TransportConfig) -> Result<Box<Self>> {
  22. let config = match &config.tls {
  23. Some(v) => v,
  24. None => {
  25. return Err(anyhow!("Missing tls config"));
  26. }
  27. };
  28. let connector = match config.trusted_root.as_ref() {
  29. Some(path) => {
  30. let s = fs::read_to_string(path)
  31. .await
  32. .with_context(|| "Failed to read the `tls.trusted_root`")?;
  33. let cert = Certificate::from_pem(s.as_bytes())
  34. .with_context(|| "Failed to read certificate from `tls.trusted_root`")?;
  35. let connector = native_tls::TlsConnector::builder()
  36. .add_root_certificate(cert)
  37. .build()?;
  38. Some(TlsConnector::from(connector))
  39. }
  40. None => None,
  41. };
  42. Ok(Box::new(TlsTransport {
  43. config: config.clone(),
  44. connector,
  45. }))
  46. }
  47. async fn bind<A: ToSocketAddrs + Send + Sync>(&self, addr: A) -> Result<Self::Acceptor> {
  48. let ident = Identity::from_pkcs12(
  49. &fs::read(self.config.pkcs12.as_ref().unwrap()).await?,
  50. self.config.pkcs12_password.as_ref().unwrap(),
  51. )
  52. .with_context(|| "Failed to create identitiy")?;
  53. let l = TcpListener::bind(addr)
  54. .await
  55. .with_context(|| "Failed to create tcp listener")?;
  56. let t = TlsAcceptor::from(native_tls::TlsAcceptor::new(ident).unwrap());
  57. Ok((l, t))
  58. }
  59. async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
  60. let (conn, addr) = a.0.accept().await?;
  61. let conn = a.1.accept(conn).await?;
  62. Ok((conn, addr))
  63. }
  64. async fn connect(&self, addr: &str) -> Result<Self::Stream> {
  65. let conn = TcpStream::connect(&addr).await?;
  66. if let Err(e) = set_tcp_keepalive(&conn) {
  67. error!(
  68. "Failed to set TCP keepalive. The connection maybe unstable: {:?}",
  69. e
  70. );
  71. }
  72. let connector = self.connector.as_ref().unwrap();
  73. Ok(connector
  74. .connect(
  75. self.config
  76. .hostname
  77. .as_ref()
  78. .unwrap_or(&String::from(addr.split(':').next().unwrap())),
  79. conn,
  80. )
  81. .await?)
  82. }
  83. }