Explorar o código

refactor: refactor handle_connection

Yujia Qiao %!s(int64=4) %!d(string=hai) anos
pai
achega
f4b7e600bc
Modificáronse 1 ficheiros con 98 adicións e 76 borrados
  1. 98 76
      src/server.rs

+ 98 - 76
src/server.rs

@@ -129,92 +129,114 @@ async fn handle_connection(
     let hello = read_hello(&mut conn).await?;
     match hello {
         ControlChannelHello(_, service_digest) => {
-            info!("New control channel incomming from {}", addr);
-
-            // Generate a nonce
-            let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
-            rand::thread_rng().fill_bytes(&mut nonce);
-
-            // Send hello
-            let hello_send = Hello::ControlChannelHello(
-                protocol::CURRENT_PROTO_VRESION,
-                nonce.clone().try_into().unwrap(),
-            );
-            conn.write_all(&bincode::serialize(&hello_send).unwrap())
+            do_control_channel_handshake(conn, addr, services, control_channels, service_digest)
                 .await?;
+        }
+        DataChannelHello(_, nonce) => {
+            do_data_channel_handshake(conn, control_channels, nonce).await?;
+        }
+    }
+    Ok(())
+}
 
-            // Lookup the service
-            let services_guard = services.read().await;
-            let service_config = match services_guard.get(&service_digest) {
-                Some(v) => v,
-                None => {
-                    conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
-                        .await?;
-                    bail!("No such a service {}", hex::encode(&service_digest));
-                }
-            };
-            let service_name = &service_config.name;
-
-            // Calculate the checksum
-            let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
-            concat.append(&mut nonce);
-
-            // Read auth
-            let d = match read_auth(&mut conn).await? {
-                protocol::Auth(v) => v,
-            };
+async fn do_control_channel_handshake(
+    mut conn: TcpStream,
+    addr: SocketAddr,
+    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
+    control_channels: Arc<RwLock<ControlChannelMap>>,
+    service_digest: ServiceDigest,
+) -> Result<()> {
+    info!("New control channel incomming from {}", addr);
+
+    // Generate a nonce
+    let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
+    rand::thread_rng().fill_bytes(&mut nonce);
+
+    // Send hello
+    let hello_send = Hello::ControlChannelHello(
+        protocol::CURRENT_PROTO_VRESION,
+        nonce.clone().try_into().unwrap(),
+    );
+    conn.write_all(&bincode::serialize(&hello_send).unwrap())
+        .await?;
+
+    // Lookup the service
+    let services_guard = services.read().await;
+    let service_config = match services_guard.get(&service_digest) {
+        Some(v) => v,
+        None => {
+            conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
+                .await?;
+            bail!("No such a service {}", hex::encode(&service_digest));
+        }
+    };
+    let service_name = &service_config.name;
+
+    // Calculate the checksum
+    let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
+    concat.append(&mut nonce);
+
+    // Read auth
+    let d = match read_auth(&mut conn).await? {
+        protocol::Auth(v) => v,
+    };
+
+    // Validate
+    let session_key = protocol::digest(&concat);
+    if session_key != d {
+        conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
+            .await?;
+        debug!(
+            "Expect {}, but got {}",
+            hex::encode(session_key),
+            hex::encode(d)
+        );
+        bail!("Service {} failed the authentication", service_name);
+    } else {
+        let mut h = control_channels.write().await;
+
+        if let Some(_) = h.remove1(&service_digest) {
+            warn!(
+                "Dropping previous control channel for digest {}",
+                hex::encode(service_digest)
+            );
+        }
 
-            // Validate
-            let session_key = protocol::digest(&concat);
-            if session_key != d {
-                conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
-                    .await?;
-                debug!(
-                    "Expect {}, but got {}",
-                    hex::encode(session_key),
-                    hex::encode(d)
-                );
-                bail!("Service {} failed the authentication", service_name);
-            } else {
-                let mut h = control_channels.write().await;
+        let service_config = service_config.clone();
+        drop(services_guard);
 
-                if let Some(_) = h.remove1(&service_digest) {
-                    warn!(
-                        "Dropping previous control channel for digest {}",
-                        hex::encode(service_digest)
-                    );
-                }
+        // Send ack
+        conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
+            .await?;
 
-                let service_config = service_config.clone();
-                drop(services_guard);
+        info!(service = %service_config.name, "Control channel established");
+        let handle = ControlChannelHandle::new(conn, service_config);
 
-                // Send ack
-                conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
-                    .await?;
+        // Drop the old handle
+        let _ = h.insert(service_digest, session_key, handle);
+    }
 
-                info!(service = %service_config.name, "Control channel established");
-                let handle = ControlChannelHandle::new(conn, service_config);
+    Ok(())
+}
 
-                // Drop the old handle
-                let _ = h.insert(service_digest, session_key, handle);
+async fn do_data_channel_handshake(
+    conn: TcpStream,
+    control_channels: Arc<RwLock<ControlChannelMap>>,
+    nonce: Nonce,
+) -> Result<()> {
+    // Validate
+    let control_channels_guard = control_channels.read().await;
+    match control_channels_guard.get2(&nonce) {
+        Some(c_ch) => {
+            if let Err(e) = set_tcp_keepalive(&conn) {
+                error!("The connection may be unstable! {:?}", e);
             }
-        }
-        DataChannelHello(_, nonce) => {
-            // Validate
-            let control_channels_guard = control_channels.read().await;
-            match control_channels_guard.get2(&nonce) {
-                Some(c_ch) => {
-                    if let Err(e) = set_tcp_keepalive(&conn) {
-                        error!("The connection may be unstable! {:?}", e);
-                    }
 
-                    // Send the data channel to the corresponding control channel
-                    c_ch.conn_pool.data_ch_tx.send(conn).await?;
-                }
-                None => {
-                    warn!("Data channel has incorrect nonce");
-                }
-            }
+            // Send the data channel to the corresponding control channel
+            c_ch.conn_pool.data_ch_tx.send(conn).await?;
+        }
+        None => {
+            warn!("Data channel has incorrect nonce");
         }
     }
     Ok(())