Parcourir la source

Fix DNS timeout handling and remove starup logs

ajasibley il y a 5 mois
Parent
commit
fbc96bf562
6 fichiers modifiés avec 85 ajouts et 44 suppressions
  1. 2 2
      README.md
  2. 4 18
      chat.go
  3. 11 2
      cmd/selftest/main.go
  4. 68 18
      dns.go
  5. 0 2
      http.go
  6. 0 2
      ssh.go

+ 2 - 2
README.md

@@ -14,7 +14,7 @@ curl ch.at/what-is-rust         # Path-based (cleaner URLs, hyphens become space
 ssh ch.at
 
 # DNS tunneling
-dig what-is-2+2.ch.at TXT
+dig @ch.at "what-is-2+2" TXT
 
 # API (OpenAI-compatible)
 curl ch.at/v1/chat/completions
@@ -146,7 +146,7 @@ Edit constants in source files:
 
 ## Limitations
 
-- **DNS**: Responses limited to ~500 chars due to protocol constraints
+- **DNS**: Responses limited to ~500 bytes. Complex queries may time out after 4s. DNS queries automatically request concise, plain-text responses
 - **History**: Limited to 2KB in web interface to prevent URL overflow
 - **Rate limiting**: Basic IP-based limiting to prevent abuse
 - **No encryption**: SSH is encrypted, but HTTP/DNS are not

+ 4 - 18
chat.go

@@ -1,10 +1,5 @@
 package main
 
-import (
-	"fmt"
-	"os"
-)
-
 // Configuration - edit source code and recompile to change settings
 // To disable a service: set its port to 0 or delete its .go file
 const (
@@ -19,18 +14,14 @@ func main() {
 	// SSH Server
 	if SSH_PORT > 0 {
 		go func() {
-			if err := StartSSHServer(SSH_PORT); err != nil {
-				fmt.Fprintf(os.Stderr, "SSH server error: %v\n", err)
-			}
+			StartSSHServer(SSH_PORT)
 		}()
 	}
 	
 	// DNS Server
 	if DNS_PORT > 0 {
 		go func() {
-			if err := StartDNSServer(DNS_PORT); err != nil {
-				fmt.Fprintf(os.Stderr, "DNS server error: %v\n", err)
-			}
+			StartDNSServer(DNS_PORT)
 		}()
 	}
 	
@@ -39,17 +30,12 @@ func main() {
 	if HTTP_PORT > 0 || HTTPS_PORT > 0 {
 		if HTTPS_PORT > 0 {
 			go func() {
-				if err := StartHTTPSServer(HTTPS_PORT, "cert.pem", "key.pem"); err != nil {
-					fmt.Fprintf(os.Stderr, "HTTPS server error: %v\n", err)
-				}
+				StartHTTPSServer(HTTPS_PORT, "cert.pem", "key.pem")
 			}()
 		}
 		
 		if HTTP_PORT > 0 {
-			if err := StartHTTPServer(HTTP_PORT); err != nil {
-				fmt.Fprintf(os.Stderr, "HTTP server error: %v\n", err)
-				os.Exit(1)
-			}
+			StartHTTPServer(HTTP_PORT)
 		} else {
 			// If only HTTPS is enabled, block forever
 			select {}

+ 11 - 2
cmd/selftest/main.go

@@ -96,6 +96,8 @@ func main() {
 
 	baseURL := strings.TrimSuffix(os.Args[1], "/")
 	
+	sshPort := "22"
+	
 	// Extract hostname from URL for SSH/DNS tests
 	hostname := "localhost"
 	if u, err := url.Parse(baseURL); err == nil && u.Hostname() != "" {
@@ -225,7 +227,7 @@ func main() {
 		sshHost = "127.0.0.1"
 	}
 	
-	sshClient, err := ssh.Dial("tcp", sshHost+":22", config)
+	sshClient, err := ssh.Dial("tcp", sshHost+":"+sshPort, config)
 	if err == nil {
 		defer sshClient.Close()
 		
@@ -317,7 +319,14 @@ func main() {
 	// Test 7: DNS protocol
 	fmt.Print("Testing DNS protocol... ")
 	// Run dig command to query the DNS server
-	cmd := exec.Command("dig", "+short", "@127.0.0.1", "-p", "53", "repeat-verbatim-the-word-pass.ch.at", "TXT")
+	// For localhost, use the query directly without domain suffix
+	var queryDomain string
+	if hostname == "localhost" || hostname == "127.0.0.1" {
+		queryDomain = "repeat-verbatim-the-word-pass"
+	} else {
+		queryDomain = "repeat-verbatim-the-word-pass." + hostname
+	}
+	cmd := exec.Command("dig", "+short", "@"+hostname, "-p", "53", queryDomain, "TXT")
 	output, err := cmd.Output()
 	if err != nil {
 		fmt.Printf("✗ (dig command failed: %v)\n", err)

+ 68 - 18
dns.go

@@ -3,58 +3,109 @@ package main
 import (
 	"fmt"
 	"strings"
+	"time"
 
 	"github.com/miekg/dns"
 )
 
 func StartDNSServer(port int) error {
-	// Set up DNS handler
 	dns.HandleFunc("ch.at.", handleDNS)
-	dns.HandleFunc(".", handleDNS) // Catch-all for any domain
+	dns.HandleFunc(".", handleDNS)
 
-	// Create and start server
 	server := &dns.Server{
 		Addr: fmt.Sprintf(":%d", port),
 		Net:  "udp",
 	}
 
-	fmt.Printf("DNS server listening on :%d\n", port)
 	return server.ListenAndServe()
 }
 
 func handleDNS(w dns.ResponseWriter, r *dns.Msg) {
-	// Rate limiting
 	if !rateLimitAllow(w.RemoteAddr().String()) {
-		return // Silently drop - DNS doesn't have error responses for rate limits
+		return
 	}
 
-	// Check if we have a question
 	if len(r.Question) == 0 {
 		return
 	}
 
-	// Build response
 	m := new(dns.Msg)
 	m.SetReply(r)
 	m.Authoritative = true
 
-	// Process each question (usually just one)
 	for _, q := range r.Question {
 		if q.Qtype != dns.TypeTXT {
-			continue // Only handle TXT queries
+			continue
 		}
 
-		// Extract the prompt from domain name
 		name := strings.TrimSuffix(strings.TrimSuffix(q.Name, "."), ".ch.at")
 		prompt := strings.ReplaceAll(name, "-", " ")
+		
+		
+		// Optimize prompt for DNS constraints
+		dnsPrompt := "Answer in 500 characters or less, no markdown formatting: " + prompt
+
+		// Stream LLM response with hard deadline
+		ch := make(chan string)
+		done := make(chan bool)
+		
+		go func() {
+			if _, err := LLM(dnsPrompt, ch); err != nil {
+				select {
+				case ch <- "Error: " + err.Error():
+				case <-done:
+				}
+			}
+			// Don't close ch here - LLM function already does it with defer
+		}()
 
-		// Get LLM response
-		response, err := LLM(prompt, nil)
-		if err != nil {
-			response = err.Error()
+		var response strings.Builder
+		deadline := time.After(4 * time.Second) // Safe middle ground for DNS clients
+		channelClosed := false
+		
+		
+		for {
+			select {
+			case chunk, ok := <-ch:
+				if !ok {
+					channelClosed = true
+					goto respond
+				}
+				response.WriteString(chunk)
+				if response.Len() >= 500 {
+					goto respond
+				}
+			case <-deadline:
+				if response.Len() == 0 {
+					response.WriteString("Request timed out")
+				} else if !channelClosed {
+					response.WriteString("... (incomplete)")
+				}
+				goto respond
+			}
 		}
 
-		// Create TXT record
+	respond:
+		close(done)
+		finalResponse := response.String()
+		if len(finalResponse) > 500 {
+			finalResponse = finalResponse[:497] + "..."
+		} else if len(finalResponse) == 500 && !channelClosed {
+			// We hit the exact limit but stream is still going
+			finalResponse = finalResponse[:497] + "..."
+		}
+		
+
+		// Split response into 255-byte chunks for DNS TXT records
+		var txtStrings []string
+		for i := 0; i < len(finalResponse); i += 255 {
+			end := i + 255
+			if end > len(finalResponse) {
+				end = len(finalResponse)
+			}
+			txtStrings = append(txtStrings, finalResponse[i:end])
+		}
+		
 		txt := &dns.TXT{
 			Hdr: dns.RR_Header{
 				Name:   q.Name,
@@ -62,11 +113,10 @@ func handleDNS(w dns.ResponseWriter, r *dns.Msg) {
 				Class:  dns.ClassINET,
 				Ttl:    60,
 			},
-			Txt: []string{response},
+			Txt: txtStrings,
 		}
 		m.Answer = append(m.Answer, txt)
 	}
 
-	// Send response
 	w.WriteMsg(m)
 }

+ 0 - 2
http.go

@@ -39,13 +39,11 @@ func StartHTTPServer(port int) error {
 	http.HandleFunc("/v1/chat/completions", handleChatCompletions)
 
 	addr := fmt.Sprintf(":%d", port)
-	fmt.Printf("HTTP server listening on %s\n", addr)
 	return http.ListenAndServe(addr, nil)
 }
 
 func StartHTTPSServer(port int, certFile, keyFile string) error {
 	addr := fmt.Sprintf(":%d", port)
-	fmt.Printf("HTTPS server listening on %s\n", addr)
 	return http.ListenAndServeTLS(addr, certFile, keyFile, nil)
 }
 

+ 0 - 2
ssh.go

@@ -30,8 +30,6 @@ func StartSSHServer(port int) error {
 	}
 	defer listener.Close()
 
-	fmt.Printf("SSH server listening on :%d\n", port)
-
 	// Simple connection limiting
 	sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections