Parcourir la source

feat: mask out token in logging (#129)

Yujia Qiao il y a 4 ans
Parent
commit
5f301ed8e3
1 fichiers modifiés avec 52 ajouts et 10 suppressions
  1. 52 10
      src/config.rs

+ 52 - 10
src/config.rs

@@ -1,11 +1,49 @@
 use anyhow::{anyhow, bail, Context, Result};
 use serde::{Deserialize, Serialize};
 use std::collections::HashMap;
+use std::fmt::{Debug, Formatter};
+use std::ops::Deref;
 use std::path::Path;
 use tokio::fs;
 
 use crate::transport::{DEFAULT_KEEPALIVE_INTERVAL, DEFAULT_KEEPALIVE_SECS, DEFAULT_NODELAY};
 
+/// String with Debug implementation that emits "MASKED"
+/// Used to mask sensitive strings when logging
+#[derive(Serialize, Deserialize, Default, PartialEq, Clone)]
+pub struct MaskedString(String);
+
+impl Debug for MaskedString {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
+        f.write_str("MASKED")
+    }
+}
+
+impl Deref for MaskedString {
+    type Target = String;
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl AsRef<[u8]> for MaskedString {
+    fn as_ref(&self) -> &[u8] {
+        self.0.as_bytes()
+    }
+}
+
+impl From<&str> for MaskedString {
+    fn from(s: &str) -> MaskedString {
+        MaskedString(String::from(s))
+    }
+}
+
+impl From<MaskedString> for String {
+    fn from(s: MaskedString) -> String {
+        s.0
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize, Copy, Clone, PartialEq)]
 pub enum TransportType {
     #[serde(rename = "tcp")]
@@ -30,7 +68,7 @@ pub struct ClientServiceConfig {
     #[serde(skip)]
     pub name: String,
     pub local_addr: String,
-    pub token: Option<String>,
+    pub token: Option<MaskedString>,
     pub nodelay: Option<bool>,
 }
 
@@ -69,7 +107,7 @@ pub struct ServerServiceConfig {
     #[serde(skip)]
     pub name: String,
     pub bind_addr: String,
-    pub token: Option<String>,
+    pub token: Option<MaskedString>,
     pub nodelay: Option<bool>,
 }
 
@@ -87,7 +125,7 @@ pub struct TlsConfig {
     pub hostname: Option<String>,
     pub trusted_root: Option<String>,
     pub pkcs12: Option<String>,
-    pub pkcs12_password: Option<String>,
+    pub pkcs12_password: Option<MaskedString>,
 }
 
 fn default_noise_pattern() -> String {
@@ -99,7 +137,7 @@ fn default_noise_pattern() -> String {
 pub struct NoiseConfig {
     #[serde(default = "default_noise_pattern")]
     pub pattern: String,
-    pub local_private_key: Option<String>,
+    pub local_private_key: Option<MaskedString>,
     pub remote_public_key: Option<String>,
     // TODO: Maybe psk can be added
 }
@@ -152,7 +190,7 @@ fn default_transport() -> TransportConfig {
 #[serde(deny_unknown_fields)]
 pub struct ClientConfig {
     pub remote_addr: String,
-    pub default_token: Option<String>,
+    pub default_token: Option<MaskedString>,
     pub services: HashMap<String, ClientServiceConfig>,
     #[serde(default = "default_transport")]
     pub transport: TransportConfig,
@@ -162,7 +200,7 @@ pub struct ClientConfig {
 #[serde(deny_unknown_fields)]
 pub struct ServerConfig {
     pub bind_addr: String,
-    pub default_token: Option<String>,
+    pub default_token: Option<MaskedString>,
     pub services: HashMap<String, ServerServiceConfig>,
     #[serde(default = "default_transport")]
     pub transport: TransportConfig,
@@ -353,7 +391,8 @@ mod tests {
                 .unwrap()
                 .token
                 .as_ref()
-                .unwrap(),
+                .unwrap()
+                .0,
             "123"
         );
 
@@ -367,7 +406,8 @@ mod tests {
                 .unwrap()
                 .token
                 .as_ref()
-                .unwrap(),
+                .unwrap()
+                .0,
             "4"
         );
         Ok(())
@@ -401,7 +441,8 @@ mod tests {
                 .unwrap()
                 .token
                 .as_ref()
-                .unwrap(),
+                .unwrap()
+                .0,
             "123"
         );
 
@@ -415,7 +456,8 @@ mod tests {
                 .unwrap()
                 .token
                 .as_ref()
-                .unwrap(),
+                .unwrap()
+                .0,
             "4"
         );
         Ok(())