tls.rs 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. use std::net::SocketAddr;
  2. use super::Transport;
  3. use crate::config::{TlsConfig, TransportConfig};
  4. use crate::helper::set_tcp_keepalive;
  5. use anyhow::{anyhow, Context, Result};
  6. use async_trait::async_trait;
  7. use tokio::fs;
  8. use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
  9. use tokio_native_tls::native_tls::{self, Certificate, Identity};
  10. use tokio_native_tls::{TlsAcceptor, TlsConnector, TlsStream};
  11. #[derive(Debug)]
  12. pub struct TlsTransport {
  13. config: TlsConfig,
  14. connector: Option<TlsConnector>,
  15. }
  16. #[async_trait]
  17. impl Transport for TlsTransport {
  18. type Acceptor = (TcpListener, TlsAcceptor);
  19. type Stream = TlsStream<TcpStream>;
  20. async fn new(config: &TransportConfig) -> Result<Self> {
  21. let config = match &config.tls {
  22. Some(v) => v,
  23. None => {
  24. return Err(anyhow!("Missing tls config"));
  25. }
  26. };
  27. let connector = match config.trusted_root.as_ref() {
  28. Some(path) => {
  29. let s = fs::read_to_string(path)
  30. .await
  31. .with_context(|| "Failed to read the `tls.trusted_root`")?;
  32. let cert = Certificate::from_pem(s.as_bytes())
  33. .with_context(|| "Failed to read certificate from `tls.trusted_root`")?;
  34. let connector = native_tls::TlsConnector::builder()
  35. .add_root_certificate(cert)
  36. .build()?;
  37. Some(TlsConnector::from(connector))
  38. }
  39. None => None,
  40. };
  41. Ok(TlsTransport {
  42. config: config.clone(),
  43. connector,
  44. })
  45. }
  46. async fn bind<A: ToSocketAddrs + Send + Sync>(&self, addr: A) -> Result<Self::Acceptor> {
  47. let ident = Identity::from_pkcs12(
  48. &fs::read(self.config.pkcs12.as_ref().unwrap()).await?,
  49. self.config.pkcs12_password.as_ref().unwrap(),
  50. )
  51. .with_context(|| "Failed to create identitiy")?;
  52. let l = TcpListener::bind(addr)
  53. .await
  54. .with_context(|| "Failed to create tcp listener")?;
  55. let t = TlsAcceptor::from(native_tls::TlsAcceptor::new(ident).unwrap());
  56. Ok((l, t))
  57. }
  58. async fn accept(&self, a: &Self::Acceptor) -> Result<(Self::Stream, SocketAddr)> {
  59. let (conn, addr) = a.0.accept().await?;
  60. set_tcp_keepalive(&conn);
  61. let conn = a.1.accept(conn).await?;
  62. Ok((conn, addr))
  63. }
  64. async fn connect(&self, addr: &str) -> Result<Self::Stream> {
  65. let conn = TcpStream::connect(&addr).await?;
  66. set_tcp_keepalive(&conn);
  67. let connector = self.connector.as_ref().unwrap();
  68. Ok(connector
  69. .connect(
  70. self.config
  71. .hostname
  72. .as_ref()
  73. .unwrap_or(&String::from(addr.split(':').next().unwrap())),
  74. conn,
  75. )
  76. .await?)
  77. }
  78. }