Explorar el Código

Replace manual implementations with standard libraries for DNS and rate limiting, add real streaming support, and fix OpenAI compatibility

ajasibley hace 6 meses
padre
commit
d50144214f
Se han modificado 11 ficheros con 739 adiciones y 674 borrados
  1. 36 52
      README.md
  2. 9 45
      chat.go
  3. 277 41
      cmd/selftest/main.go
  4. 44 181
      dns.go
  5. 15 3
      go.mod
  6. 20 6
      go.sum
  7. 151 38
      http.go
  8. 138 0
      llm.go.example
  9. 0 147
      openai.go
  10. 23 57
      ssh.go
  11. 26 104
      util.go

+ 36 - 52
README.md

@@ -9,23 +9,25 @@ A lightweight language model chat service accessible through HTTP, SSH, DNS, and
 open https://ch.at
 open https://ch.at
 
 
 # Terminal
 # Terminal
-curl ch.at/?q=hello
+curl ch.at/?q=hello             # Query parameter (handles special chars)
+curl ch.at/what-is-rust         # Path-based (cleaner URLs, hyphens become spaces)
 ssh ch.at
 ssh ch.at
 
 
 # DNS tunneling
 # DNS tunneling
 dig what-is-2+2.ch.at TXT
 dig what-is-2+2.ch.at TXT
 
 
 # API (OpenAI-compatible)
 # API (OpenAI-compatible)
-curl ch.at:8080/v1/chat/completions
+curl ch.at/v1/chat/completions
 ```
 ```
 
 
 ## Design
 ## Design
 
 
-- ~1,100 lines of Go, one external dependency
+- ~1,200 lines of Go, three dependencies
 - Single static binary
 - Single static binary
 - No accounts, no logs, no tracking
 - No accounts, no logs, no tracking
 - Configuration through source code (edit and recompile)
 - Configuration through source code (edit and recompile)
 
 
+
 ## Privacy
 ## Privacy
 
 
 Privacy by design:
 Privacy by design:
@@ -41,72 +43,54 @@ Privacy by design:
 
 
 ## Installation
 ## Installation
 
 
-Create `llm.go` (gitignored):
+### Quick Start
 
 
-```go
-// llm.go - Create this file (it's gitignored)
-package main
-
-import (
-	"bufio"
-	"bytes"
-	"context"
-	"encoding/json"
-	"fmt"
-	"io"
-	"net/http"
-	"strings"
-)
+```bash
+# Copy the example LLM configuration (llm.go is gitignored)
+cp llm.go.example llm.go
 
 
-func getLLMResponse(ctx context.Context, prompt string) (string, error) {
-	var response strings.Builder
-	stream, err := getLLMResponseStream(ctx, prompt)
-	if err != nil {
-		return "", err
-	}
-	for chunk := range stream {
-		response.WriteString(chunk)
-	}
-	return response.String(), nil
-}
-
-func getLLMResponseStream(ctx context.Context, prompt string) (<-chan string, error) {
-	endpoint := "https://api.openai.com/v1/chat/completions"
-	key := "YOUR-OPENAI-API-KEY-HERE"  // Replace with your key
-	
-	payload := map[string]interface{}{
-		"model": "gpt-4o",
-		"messages": []map[string]string{
-			{"role": "user", "content": prompt},
-		},
-		"stream": true,
-	}
-	
-	// ... rest of implementation
-}
-```
+# Edit llm.go and add your API key
+# Supports OpenAI, Anthropic Claude, or local models (Ollama)
 
 
-Then build:
-```bash
+# For HTTPS, you'll need cert.pem and key.pem files:
+# Option 1: Use Let's Encrypt (recommended for production)
+# Option 2: Use your existing certificates
+# Option 3: Self-signed for testing:
+#   openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -nodes
+
+# Build and run
 go build -o chat .
 go build -o chat .
 sudo ./chat  # Needs root for ports 80/443/53/22
 sudo ./chat  # Needs root for ports 80/443/53/22
+```
+
+### Testing
 
 
-# Optional: build selftest tool
+```bash
+# Build the self-test tool
 go build -o selftest ./cmd/selftest
 go build -o selftest ./cmd/selftest
+
+# Run all protocol tests
+./selftest http://localhost
+
+# Test specific queries
+curl localhost/what-is-go
+curl localhost/?q=hello
 ```
 ```
 
 
-To run on high ports, edit the constants in `chat.go` and rebuild:
+### High Port Configuration
+
+To run without sudo, edit the constants in `chat.go`:
+
 ```go
 ```go
 const (
 const (
     HTTP_PORT   = 8080  // Instead of 80
     HTTP_PORT   = 8080  // Instead of 80
     HTTPS_PORT  = 0     // Disabled
     HTTPS_PORT  = 0     // Disabled
     SSH_PORT    = 2222  // Instead of 22
     SSH_PORT    = 2222  // Instead of 22
     DNS_PORT    = 0     // Disabled
     DNS_PORT    = 0     // Disabled
-    OPENAI_PORT = 8080  // Same as HTTP
 )
 )
 ```
 ```
 
 
-Then:
+Then build:
 ```bash
 ```bash
 go build -o chat .
 go build -o chat .
 ./chat  # No sudo needed for high ports
 ./chat  # No sudo needed for high ports
@@ -120,7 +104,7 @@ go build -o chat .
 Edit constants in source files:
 Edit constants in source files:
 - Ports: `chat.go` (set to 0 to disable)
 - Ports: `chat.go` (set to 0 to disable)
 - Rate limits: `util.go`
 - Rate limits: `util.go`
-- Remove protocol: Delete its .go file
+- Remove service: Delete its .go file
 
 
 ## Limitations
 ## Limitations
 
 

+ 9 - 45
chat.go

@@ -8,26 +8,18 @@ import (
 // Configuration - edit source code and recompile to change settings
 // Configuration - edit source code and recompile to change settings
 // To disable a service: set its port to 0 or delete its .go file
 // To disable a service: set its port to 0 or delete its .go file
 const (
 const (
-	HTTP_PORT   = 80    // Web interface (set to 0 to disable)
-	HTTPS_PORT  = 443   // TLS web interface (set to 0 to disable)
-	SSH_PORT    = 22    // Anonymous SSH chat (set to 0 to disable)
-	DNS_PORT    = 53    // DNS TXT chat (set to 0 to disable)
-	OPENAI_PORT = 8080  // OpenAI-compatible API (set to 0 to disable)
-	CERT_FILE   = "cert.pem" // TLS certificate for HTTPS
-	KEY_FILE    = "key.pem"  // TLS key for HTTPS
+	HTTP_PORT  = 80  // Web interface (set to 0 to disable)
+	HTTPS_PORT = 443 // TLS web interface (set to 0 to disable)
+	SSH_PORT   = 22  // Anonymous SSH chat (set to 0 to disable)
+	DNS_PORT   = 53  // DNS TXT chat (set to 0 to disable)
 )
 )
 
 
+
 func main() {
 func main() {
 	// SSH Server
 	// SSH Server
 	if SSH_PORT > 0 {
 	if SSH_PORT > 0 {
 		go func() {
 		go func() {
-			defer func() {
-				if r := recover(); r != nil {
-					fmt.Fprintf(os.Stderr, "SSH server panic: %v\n", r)
-				}
-			}()
-			sshServer := NewSSHServer(SSH_PORT)
-			if err := sshServer.Start(); err != nil {
+			if err := StartSSHServer(SSH_PORT); err != nil {
 				fmt.Fprintf(os.Stderr, "SSH server error: %v\n", err)
 				fmt.Fprintf(os.Stderr, "SSH server error: %v\n", err)
 			}
 			}
 		}()
 		}()
@@ -36,53 +28,25 @@ func main() {
 	// DNS Server
 	// DNS Server
 	if DNS_PORT > 0 {
 	if DNS_PORT > 0 {
 		go func() {
 		go func() {
-			defer func() {
-				if r := recover(); r != nil {
-					fmt.Fprintf(os.Stderr, "DNS server panic: %v\n", r)
-				}
-			}()
-			dnsServer := NewDNSServer(DNS_PORT)
-			if err := dnsServer.Start(); err != nil {
+			if err := StartDNSServer(DNS_PORT); err != nil {
 				fmt.Fprintf(os.Stderr, "DNS server error: %v\n", err)
 				fmt.Fprintf(os.Stderr, "DNS server error: %v\n", err)
 			}
 			}
 		}()
 		}()
 	}
 	}
 	
 	
-	// OpenAI API Server
-	if OPENAI_PORT > 0 {
-		go func() {
-			defer func() {
-				if r := recover(); r != nil {
-					fmt.Fprintf(os.Stderr, "OpenAI API server panic: %v\n", r)
-				}
-			}()
-			openaiServer := NewOpenAIServer(OPENAI_PORT)
-			if err := openaiServer.Start(); err != nil {
-				fmt.Fprintf(os.Stderr, "OpenAI API server error: %v\n", err)
-			}
-		}()
-	}
-	
 	// HTTP/HTTPS Server
 	// HTTP/HTTPS Server
 	// TODO: Implement graceful shutdown with signal handling
 	// TODO: Implement graceful shutdown with signal handling
 	if HTTP_PORT > 0 || HTTPS_PORT > 0 {
 	if HTTP_PORT > 0 || HTTPS_PORT > 0 {
-		httpServer := NewHTTPServer(HTTP_PORT)
-		
 		if HTTPS_PORT > 0 {
 		if HTTPS_PORT > 0 {
 			go func() {
 			go func() {
-				defer func() {
-					if r := recover(); r != nil {
-						fmt.Fprintf(os.Stderr, "HTTPS server panic: %v\n", r)
-					}
-				}()
-				if err := httpServer.StartTLS(HTTPS_PORT, CERT_FILE, KEY_FILE); err != nil {
+				if err := StartHTTPSServer(HTTPS_PORT, "cert.pem", "key.pem"); err != nil {
 					fmt.Fprintf(os.Stderr, "HTTPS server error: %v\n", err)
 					fmt.Fprintf(os.Stderr, "HTTPS server error: %v\n", err)
 				}
 				}
 			}()
 			}()
 		}
 		}
 		
 		
 		if HTTP_PORT > 0 {
 		if HTTP_PORT > 0 {
-			if err := httpServer.Start(); err != nil {
+			if err := StartHTTPServer(HTTP_PORT); err != nil {
 				fmt.Fprintf(os.Stderr, "HTTP server error: %v\n", err)
 				fmt.Fprintf(os.Stderr, "HTTP server error: %v\n", err)
 				os.Exit(1)
 				os.Exit(1)
 			}
 			}

+ 277 - 41
cmd/selftest/main.go

@@ -6,10 +6,87 @@ import (
 	"fmt"
 	"fmt"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
+	"net/url"
 	"os"
 	"os"
+	"os/exec"
 	"strings"
 	"strings"
+	"time"
+
+	"golang.org/x/crypto/ssh"
 )
 )
 
 
+
+// extractLLMResponse extracts just the LLM response from various formats
+func extractLLMResponse(body string, contentType string) string {
+	body = strings.TrimSpace(body)
+	
+	// For error responses, return empty to fail the test
+	if strings.Contains(body, "error") || strings.Contains(body, "Error") {
+		return ""
+	}
+	
+	// For JSON responses
+	if strings.Contains(contentType, "json") {
+		var result map[string]string
+		if err := json.Unmarshal([]byte(body), &result); err == nil {
+			if answer, ok := result["answer"]; ok {
+				return strings.TrimSpace(answer)
+			}
+		}
+		return ""
+	}
+	
+	// For plain text Q&A format, extract just the answer
+	if strings.Contains(body, "\nA: ") {
+		lines := strings.Split(body, "\n")
+		for _, line := range lines {
+			if strings.HasPrefix(line, "A: ") {
+				return strings.TrimSpace(strings.TrimPrefix(line, "A: "))
+			}
+		}
+	}
+	
+	// Otherwise return trimmed body
+	return body
+}
+
+func checkResponse(resp *http.Response, err error, passed, failed *int) {
+	if err != nil {
+		fmt.Println("✗ (request failed)")
+		*failed++
+		return
+	}
+	defer resp.Body.Close()
+	
+	body, _ := io.ReadAll(resp.Body)
+	
+	if resp.StatusCode != 200 {
+		fmt.Printf("✗ (status %d)\n", resp.StatusCode)
+		*failed++
+		return
+	}
+	
+	contentType := resp.Header.Get("Content-Type")
+	llmResponse := extractLLMResponse(string(body), contentType)
+	
+	// Check for exact match "pass"
+	if llmResponse == "pass" {
+		fmt.Println("✓")
+		*passed++
+	} else {
+		// Show what we got instead
+		preview := llmResponse
+		if preview == "" {
+			preview = "error or empty response"
+		} else if len(preview) > 50 {
+			preview = preview[:50] + "..."
+		}
+		fmt.Printf("✗ (expected 'pass', got: %q)\n", preview)
+		*failed++
+	}
+}
+
+
 func main() {
 func main() {
 	if len(os.Args) < 2 {
 	if len(os.Args) < 2 {
 		fmt.Println("Usage: selftest <base-url>")
 		fmt.Println("Usage: selftest <base-url>")
@@ -18,59 +95,58 @@ func main() {
 	}
 	}
 
 
 	baseURL := strings.TrimSuffix(os.Args[1], "/")
 	baseURL := strings.TrimSuffix(os.Args[1], "/")
+	
+	// Extract hostname from URL for SSH/DNS tests
+	hostname := "localhost"
+	if u, err := url.Parse(baseURL); err == nil && u.Hostname() != "" {
+		hostname = u.Hostname()
+	}
+	
 	passed := 0
 	passed := 0
 	failed := 0
 	failed := 0
 
 
+	// Add delay between tests to avoid rate limiting
+	testDelay := 700 * time.Millisecond
+
 	// Test 1: Basic HTTP GET
 	// Test 1: Basic HTTP GET
 	fmt.Print("Testing HTTP GET... ")
 	fmt.Print("Testing HTTP GET... ")
-	resp, err := http.Get(baseURL + "/?q=hello")
-	if err == nil && resp.StatusCode == 200 {
-		body, _ := io.ReadAll(resp.Body)
-		resp.Body.Close()
-		if strings.Contains(string(body), "hello") || strings.Contains(string(body), "Hello") {
-			fmt.Println("✓")
-			passed++
-		} else {
-			fmt.Println("✗ (unexpected response)")
-			failed++
-		}
-	} else {
-		fmt.Println("✗ (request failed)")
-		failed++
-	}
+	resp, err := http.Get(baseURL + "/?q=repeat+verbatim+the+word+pass")
+	checkResponse(resp, err, &passed, &failed)
+
+	time.Sleep(testDelay)
 
 
 	// Test 2: HTTP POST
 	// Test 2: HTTP POST
 	fmt.Print("Testing HTTP POST... ")
 	fmt.Print("Testing HTTP POST... ")
-	resp, err = http.Post(baseURL+"/", "text/plain", strings.NewReader("What is 2+2?"))
-	if err == nil && resp.StatusCode == 200 {
-		body, _ := io.ReadAll(resp.Body)
-		resp.Body.Close()
-		if strings.Contains(string(body), "4") || strings.Contains(string(body), "four") {
-			fmt.Println("✓")
-			passed++
-		} else {
-			fmt.Println("✗ (unexpected response)")
-			failed++
-		}
-	} else {
-		fmt.Println("✗ (request failed)")
-		failed++
-	}
+	resp, err = http.Post(baseURL+"/", "text/plain", strings.NewReader("repeat verbatim the word pass"))
+	checkResponse(resp, err, &passed, &failed)
 
 
-	// Test 3: JSON API
+	time.Sleep(testDelay)
+
+	// Test 3: Path-based query
+	fmt.Print("Testing path-based query... ")
+	resp, err = http.Get(baseURL + "/repeat-verbatim-the-word-pass")
+	checkResponse(resp, err, &passed, &failed)
+
+	time.Sleep(testDelay)
+
+	// Test 4: JSON API
 	fmt.Print("Testing JSON API... ")
 	fmt.Print("Testing JSON API... ")
-	req, _ := http.NewRequest("GET", baseURL+"/?q=test", nil)
+	req, _ := http.NewRequest("GET", baseURL+"/?q=repeat+verbatim+the+word+pass", nil)
 	req.Header.Set("Accept", "application/json")
 	req.Header.Set("Accept", "application/json")
 	resp, err = http.DefaultClient.Do(req)
 	resp, err = http.DefaultClient.Do(req)
 	if err == nil && resp.StatusCode == 200 {
 	if err == nil && resp.StatusCode == 200 {
 		var result map[string]string
 		var result map[string]string
 		json.NewDecoder(resp.Body).Decode(&result)
 		json.NewDecoder(resp.Body).Decode(&result)
 		resp.Body.Close()
 		resp.Body.Close()
-		if result["question"] == "test" && result["answer"] != "" {
+		if result["question"] == "repeat verbatim the word pass" && result["answer"] == "pass" {
 			fmt.Println("✓")
 			fmt.Println("✓")
 			passed++
 			passed++
 		} else {
 		} else {
-			fmt.Println("✗ (invalid JSON response)")
+			answer := result["answer"]
+			if answer == "" {
+				answer = "no answer field"
+			}
+			fmt.Printf("✗ (expected 'pass', got: %q)\n", answer)
 			failed++
 			failed++
 		}
 		}
 	} else {
 	} else {
@@ -78,25 +154,48 @@ func main() {
 		failed++
 		failed++
 	}
 	}
 
 
-	// Test 4: OpenAI API compatibility
+	time.Sleep(testDelay)
+
+	// Test 5: OpenAI API compatibility
 	fmt.Print("Testing OpenAI API... ")
 	fmt.Print("Testing OpenAI API... ")
 	payload := map[string]interface{}{
 	payload := map[string]interface{}{
 		"model": "gpt-4o",
 		"model": "gpt-4o",
 		"messages": []map[string]string{
 		"messages": []map[string]string{
-			{"role": "user", "content": "Say 'test passed'"},
+			{"role": "user", "content": "repeat verbatim the word pass"},
 		},
 		},
 	}
 	}
 	jsonData, _ := json.Marshal(payload)
 	jsonData, _ := json.Marshal(payload)
-	// OpenAI API is on port 8080 in production
-	apiURL := "http://localhost:8080/v1/chat/completions"
+	// OpenAI API is on main HTTP port when OPENAI_PORT=0
+	apiURL := baseURL + "/v1/chat/completions"
 	resp, err = http.Post(apiURL, "application/json", bytes.NewReader(jsonData))
 	resp, err = http.Post(apiURL, "application/json", bytes.NewReader(jsonData))
 	if err == nil && resp.StatusCode == 200 {
 	if err == nil && resp.StatusCode == 200 {
 		var result map[string]interface{}
 		var result map[string]interface{}
 		json.NewDecoder(resp.Body).Decode(&result)
 		json.NewDecoder(resp.Body).Decode(&result)
 		resp.Body.Close()
 		resp.Body.Close()
 		if choices, ok := result["choices"].([]interface{}); ok && len(choices) > 0 {
 		if choices, ok := result["choices"].([]interface{}); ok && len(choices) > 0 {
-			fmt.Println("✓")
-			passed++
+			if choice, ok := choices[0].(map[string]interface{}); ok {
+				if message, ok := choice["message"].(map[string]interface{}); ok {
+					if content, ok := message["content"].(string); ok {
+						content = strings.TrimSpace(content)
+						if content == "pass" {
+							fmt.Println("✓")
+							passed++
+						} else {
+							fmt.Printf("✗ (expected 'pass', got: %q)\n", content)
+							failed++
+						}
+					} else {
+						fmt.Println("✗ (no content in message)")
+						failed++
+					}
+				} else {
+					fmt.Println("✗ (invalid message format)")
+					failed++
+				}
+			} else {
+				fmt.Println("✗ (invalid choice format)")
+				failed++
+			}
 		} else {
 		} else {
 			fmt.Println("✗ (invalid response format)")
 			fmt.Println("✗ (invalid response format)")
 			failed++
 			failed++
@@ -106,7 +205,144 @@ func main() {
 		failed++
 		failed++
 	}
 	}
 
 
-	// Test 5: Rate limiting (default is 100 requests/minute)
+	time.Sleep(testDelay)
+
+	// Test 6: SSH protocol
+	fmt.Print("Testing SSH protocol... ")
+	config := &ssh.ClientConfig{
+		User: "anonymous", 
+		Auth: []ssh.AuthMethod{
+			ssh.Password(""),
+		},
+		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
+		Timeout:         5 * time.Second,
+		ClientVersion:   "SSH-2.0-Go", // Explicitly set version
+	}
+	
+	// Use 127.0.0.1 instead of localhost to avoid IPv6 issues
+	sshHost := hostname
+	if hostname == "localhost" {
+		sshHost = "127.0.0.1"
+	}
+	
+	sshClient, err := ssh.Dial("tcp", sshHost+":22", config)
+	if err == nil {
+		defer sshClient.Close()
+		
+		// Create a session and send a real query
+		session, err := sshClient.NewSession()
+		if err == nil {
+			defer session.Close()
+			
+			// Request PTY to simulate real terminal
+			if err := session.RequestPty("xterm", 80, 40, ssh.TerminalModes{}); err == nil {
+				// Set up pipes for input/output
+				stdin, _ := session.StdinPipe()
+				stdout, _ := session.StdoutPipe()
+				
+				// Start shell
+				if err := session.Shell(); err == nil {
+					// Send query
+					stdin.Write([]byte("repeat verbatim the word pass\n"))
+					stdin.Close()
+					
+					// Read response (with timeout)
+					done := make(chan bool)
+					var output []byte
+					go func() {
+						output, _ = io.ReadAll(stdout)
+						done <- true
+					}()
+					
+					select {
+					case <-done:
+						outputStr := string(output)
+						// Extract just the LLM response from SSH output
+						// Look for lines after our query
+						lines := strings.Split(outputStr, "\n")
+						llmResponse := ""
+						for i, line := range lines {
+							line = strings.TrimSpace(strings.TrimSuffix(line, "\r"))
+							// Find the line containing our query
+							if strings.Contains(line, "repeat verbatim the word pass") && i+1 < len(lines) {
+								// The response should be on the next line
+								nextLine := strings.TrimSpace(strings.TrimSuffix(lines[i+1], "\r"))
+								// Skip if it's a prompt line
+								if nextLine != "" && !strings.HasPrefix(nextLine, ">") {
+									llmResponse = nextLine
+									break
+								}
+							}
+						}
+						
+						// Check for response
+						if llmResponse == "pass" {
+							fmt.Println("✓")
+							passed++
+						} else {
+							if llmResponse == "" {
+								llmResponse = "no response extracted"
+							}
+							fmt.Printf("✗ (expected 'pass', got: %q)\n", llmResponse)
+							failed++
+						}
+					case <-time.After(3 * time.Second):
+						fmt.Println("✗ (SSH timeout)")
+						failed++
+					}
+				} else {
+					fmt.Println("✗ (SSH shell failed)")
+					failed++
+				}
+			} else {
+				fmt.Println("✗ (SSH PTY failed)")
+				failed++
+			}
+		} else {
+			fmt.Println("✗ (SSH session failed)")
+			failed++
+		}
+	} else {
+		// Try to understand the error
+		if strings.Contains(err.Error(), "handshake failed") {
+			fmt.Println("✗ (SSH handshake failed - server may require different auth)")
+		} else {
+			fmt.Printf("✗ (SSH failed: %v)\n", err)
+		}
+		failed++
+	}
+
+	time.Sleep(testDelay)
+
+	// Test 7: DNS protocol
+	fmt.Print("Testing DNS protocol... ")
+	// Run dig command to query the DNS server
+	cmd := exec.Command("dig", "+short", "@127.0.0.1", "-p", "53", "repeat-verbatim-the-word-pass.ch.at", "TXT")
+	output, err := cmd.Output()
+	if err != nil {
+		fmt.Printf("✗ (dig command failed: %v)\n", err)
+		failed++
+	} else {
+		outputStr := strings.TrimSpace(string(output))
+		// DNS TXT records come with quotes, remove them
+		outputStr = strings.Trim(outputStr, "\"")
+		
+		// Check for response
+		if outputStr == "pass" {
+			fmt.Println("✓")
+			passed++
+		} else {
+			if outputStr == "" {
+				outputStr = "empty response"
+			}
+			fmt.Printf("✗ (expected 'pass', got: %q)\n", outputStr)
+			failed++
+		}
+	}
+
+	time.Sleep(testDelay)
+
+	// Test 8: Rate limiting (default is 100 requests/minute)
 	fmt.Print("Testing rate limiting... ")
 	fmt.Print("Testing rate limiting... ")
 	rateLimitHit := false
 	rateLimitHit := false
 	// Make requests quickly to trigger rate limit
 	// Make requests quickly to trigger rate limit

+ 44 - 181
dns.go

@@ -1,209 +1,72 @@
 package main
 package main
 
 
 import (
 import (
-	"context"
 	"fmt"
 	"fmt"
-	"net"
 	"strings"
 	"strings"
+
+	"github.com/miekg/dns"
 )
 )
 
 
-type DNSServer struct {
-	port int
-}
+func StartDNSServer(port int) error {
+	// Set up DNS handler
+	dns.HandleFunc("ch.at.", handleDNS)
+	dns.HandleFunc(".", handleDNS) // Catch-all for any domain
 
 
-func NewDNSServer(port int) *DNSServer {
-	return &DNSServer{port: port}
-}
-
-func (s *DNSServer) Start() error {
-	addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", s.port))
-	if err != nil {
-		return err
+	// Create and start server
+	server := &dns.Server{
+		Addr: fmt.Sprintf(":%d", port),
+		Net:  "udp",
 	}
 	}
 
 
-	conn, err := net.ListenUDP("udp", addr)
-	if err != nil {
-		return err
-	}
-	defer conn.Close()
-
-	fmt.Printf("DNS server listening on :%d\n", s.port)
-
-	buf := make([]byte, 512) // DNS messages are typically small
-	for {
-		n, clientAddr, err := conn.ReadFromUDP(buf)
-		if err != nil {
-			// Read error - continue
-			continue
-		}
-
-		go s.handleQuery(conn, clientAddr, buf[:n])
-	}
+	fmt.Printf("DNS server listening on :%d\n", port)
+	return server.ListenAndServe()
 }
 }
 
 
-func (s *DNSServer) handleQuery(conn *net.UDPConn, addr *net.UDPAddr, query []byte) {
-	// Validate minimum DNS packet size
-	if len(query) < 12 {
-		return
-	}
-
+func handleDNS(w dns.ResponseWriter, r *dns.Msg) {
 	// Rate limiting
 	// Rate limiting
-	if !rateLimiter.Allow(addr.String()) {
+	if !rateLimitAllow(w.RemoteAddr().String()) {
 		return // Silently drop - DNS doesn't have error responses for rate limits
 		return // Silently drop - DNS doesn't have error responses for rate limits
 	}
 	}
 
 
-	// Validate DNS header flags (must be a query, not response)
-	if query[2]&0x80 != 0 {
-		return // It's a response, not a query
-	}
-
-	// Extract question from query
-	question := extractQuestion(query)
-	if question == "" {
+	// Check if we have a question
+	if len(r.Question) == 0 {
 		return
 		return
 	}
 	}
 
 
-	// Remove .ch.at suffix if present
-	question = strings.TrimSuffix(question, ".ch.at")
-	question = strings.TrimSuffix(question, ".")
-	
-	// Convert DNS format to readable (replace - with space)
-	prompt := strings.ReplaceAll(question, "-", " ")
+	// Build response
+	m := new(dns.Msg)
+	m.SetReply(r)
+	m.Authoritative = true
 
 
-	// Get LLM response
-	ctx := context.Background()
-	response, err := getLLMResponse(ctx, prompt)
-	if err != nil {
-		response = "Error: " + err.Error()
-	}
+	// Process each question (usually just one)
+	for _, q := range r.Question {
+		if q.Qtype != dns.TypeTXT {
+			continue // Only handle TXT queries
+		}
 
 
-	// Build DNS response with chunked TXT records
-	reply := buildDNSResponse(query, response)
-	
-	// Ensure response fits in UDP packet (RFC recommends 512 bytes)
-	if len(reply) > 512 {
-		// Truncate and set TC bit
-		reply = reply[:512]
-		reply[2] |= 0x02 // Set TC (truncation) bit
-	}
-	
-	conn.WriteToUDP(reply, addr)
-}
+		// Extract the prompt from domain name
+		name := strings.TrimSuffix(strings.TrimSuffix(q.Name, "."), ".ch.at")
+		prompt := strings.ReplaceAll(name, "-", " ")
 
 
-func extractQuestion(query []byte) string {
-	// Skip header (12 bytes)
-	if len(query) < 12 {
-		return ""
-	}
-	
-	pos := 12
-	var name []string
-	totalLength := 0
-	
-	// Parse domain name labels (max 128 to prevent DoS)
-	for i := 0; i < 128 && pos < len(query); i++ {
-		if pos >= len(query) {
-			return ""
-		}
-		
-		length := int(query[pos])
-		if length == 0 {
-			break
-		}
-		
-		// DNS compression uses first 2 bits = 11 (0xC0)
-		// We reject these for simplicity and security
-		if length&0xC0 == 0xC0 {
-			return ""
-		}
-		
-		// DNS label length must be <= 63
-		if length > 63 {
-			return ""
-		}
-		
-		pos++
-		if pos+length > len(query) {
-			return ""
-		}
-		
-		// Track total domain name length (max 255)
-		totalLength += length + 1
-		if totalLength > 255 {
-			return ""
+		// Get LLM response
+		response, err := LLM(prompt, nil)
+		if err != nil {
+			response = err.Error()
 		}
 		}
-		
-		// Validate label contains reasonable characters
-		label := query[pos : pos+length]
-		name = append(name, string(label))
-		pos += length
-	}
-	
-	// Ensure we read a complete question (should have type and class after)
-	if pos+4 > len(query) {
-		return ""
-	}
-	
-	return strings.Join(name, ".")
-}
 
 
-func buildDNSResponse(query []byte, answer string) []byte {
-	resp := make([]byte, len(query))
-	copy(resp, query)
-	
-	// Set response flags (QR=1, AA=1)
-	resp[2] = 0x81
-	resp[3] = 0x80
-	
-	// Set answer count to 1
-	resp[7] = 1
-	
-	// Skip to end of question section
-	pos := 12
-	for pos < len(resp) {
-		if resp[pos] == 0 {
-			pos += 5 // Skip null terminator + type + class
-			break
+		// Create TXT record
+		txt := &dns.TXT{
+			Hdr: dns.RR_Header{
+				Name:   q.Name,
+				Rrtype: dns.TypeTXT,
+				Class:  dns.ClassINET,
+				Ttl:    60,
+			},
+			Txt: []string{response},
 		}
 		}
-		pos++
+		m.Answer = append(m.Answer, txt)
 	}
 	}
-	
-	// Add answer section
-	// Pointer to question name
-	resp = append(resp, 0xc0, 0x0c)
-	
-	// Type TXT (16), Class IN (1)
-	resp = append(resp, 0x00, 0x10, 0x00, 0x01)
-	
-	// TTL (0)
-	resp = append(resp, 0x00, 0x00, 0x00, 0x00)
-	
-	// Build TXT record data with chunking
-	txtData := buildTXTData(answer)
-	
-	// Data length
-	resp = append(resp, byte(len(txtData)>>8), byte(len(txtData)))
-	
-	// TXT data
-	resp = append(resp, txtData...)
-	
-	return resp
-}
 
 
-func buildTXTData(text string) []byte {
-	var data []byte
-	
-	// Split into 255-byte chunks
-	for len(text) > 0 {
-		chunkLen := len(text)
-		if chunkLen > 255 {
-			chunkLen = 255
-		}
-		
-		data = append(data, byte(chunkLen))
-		data = append(data, text[:chunkLen]...)
-		text = text[chunkLen:]
-	}
-	
-	return data
+	// Send response
+	w.WriteMsg(m)
 }
 }

+ 15 - 3
go.mod

@@ -1,7 +1,19 @@
 module ch.at
 module ch.at
 
 
-go 1.21
+go 1.23.0
 
 
-require golang.org/x/crypto v0.28.0
+toolchain go1.24.3
 
 
-require golang.org/x/sys v0.26.0 // indirect
+require (
+	github.com/miekg/dns v1.1.66
+	golang.org/x/crypto v0.37.0
+	golang.org/x/time v0.12.0
+)
+
+require (
+	golang.org/x/mod v0.24.0 // indirect
+	golang.org/x/net v0.39.0 // indirect
+	golang.org/x/sync v0.13.0 // indirect
+	golang.org/x/sys v0.32.0 // indirect
+	golang.org/x/tools v0.32.0 // indirect
+)

+ 20 - 6
go.sum

@@ -1,6 +1,20 @@
-golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
-golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
-golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo=
-golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
-golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24=
-golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M=
+github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
+github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/miekg/dns v1.1.66 h1:FeZXOS3VCVsKnEAd+wBkjMC3D2K+ww66Cq3VnCINuJE=
+github.com/miekg/dns v1.1.66/go.mod h1:jGFzBsSNbJw6z1HYut1RKBKHA9PBdxeHrZG8J+gC2WE=
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE=
+golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc=
+golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU=
+golang.org/x/mod v0.24.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
+golang.org/x/net v0.39.0 h1:ZCu7HMWDxpXpaiKdhzIfaltL9Lp31x/3fCP11bc6/fY=
+golang.org/x/net v0.39.0/go.mod h1:X7NRbYVEA+ewNkCNyJ513WmMdQ3BineSwVtN2zD/d+E=
+golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610=
+golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
+golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20=
+golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
+golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o=
+golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw=
+golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
+golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
+golang.org/x/tools v0.32.0 h1:Q7N1vhpkQv7ybVzLFtTjvQya2ewbwNDZzUgfXGqtMWU=
+golang.org/x/tools v0.32.0/go.mod h1:ZxrU41P/wAbZD8EDa6dDCa6XfpkhJ7HFMjHJXfBDu8s=

+ 151 - 38
http.go

@@ -1,7 +1,6 @@
 package main
 package main
 
 
 import (
 import (
-	"context"
 	"encoding/json"
 	"encoding/json"
 	"fmt"
 	"fmt"
 	"html"
 	"html"
@@ -35,31 +34,24 @@ const minimalHTML = `<!DOCTYPE html>
 </body>
 </body>
 </html>`
 </html>`
 
 
-type HTTPServer struct {
-	port int
-}
-
-func NewHTTPServer(port int) *HTTPServer {
-	return &HTTPServer{port: port}
-}
-
-func (s *HTTPServer) Start() error {
-	http.HandleFunc("/", s.handleRoot)
+func StartHTTPServer(port int) error {
+	http.HandleFunc("/", handleRoot)
+	http.HandleFunc("/v1/chat/completions", handleChatCompletions)
 
 
-	addr := fmt.Sprintf(":%d", s.port)
+	addr := fmt.Sprintf(":%d", port)
 	fmt.Printf("HTTP server listening on %s\n", addr)
 	fmt.Printf("HTTP server listening on %s\n", addr)
 	return http.ListenAndServe(addr, nil)
 	return http.ListenAndServe(addr, nil)
 }
 }
 
 
-func (s *HTTPServer) StartTLS(port int, certFile, keyFile string) error {
+func StartHTTPSServer(port int, certFile, keyFile string) error {
 	addr := fmt.Sprintf(":%d", port)
 	addr := fmt.Sprintf(":%d", port)
 	fmt.Printf("HTTPS server listening on %s\n", addr)
 	fmt.Printf("HTTPS server listening on %s\n", addr)
 	return http.ListenAndServeTLS(addr, certFile, keyFile, nil)
 	return http.ListenAndServeTLS(addr, certFile, keyFile, nil)
 }
 }
 
 
-func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
-	if !rateLimiter.Allow(r.RemoteAddr) {
-		http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
+func handleRoot(w http.ResponseWriter, r *http.Request) {
+	if !rateLimitAllow(r.RemoteAddr) {
+		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
 		return
 		return
 	}
 	}
 
 
@@ -90,12 +82,13 @@ func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
 			query = string(body)
 			query = string(body)
 		}
 		}
 	} else {
 	} else {
-		// GET request - no history
 		query = r.URL.Query().Get("q")
 		query = r.URL.Query().Get("q")
+		// Also support path-based queries like /what-is-go
+		if query == "" && r.URL.Path != "/" {
+			query = strings.ReplaceAll(strings.TrimPrefix(r.URL.Path, "/"), "-", " ")
+		}
 	}
 	}
 
 
-	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
-	defer cancel()
 
 
 	if query != "" {
 	if query != "" {
 		// Build prompt with history
 		// Build prompt with history
@@ -104,9 +97,9 @@ func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
 			prompt = history + "Q: " + query
 			prompt = history + "Q: " + query
 		}
 		}
 
 
-		response, err := getLLMResponse(ctx, prompt)
+		response, err := LLM(prompt, nil)
 		if err != nil {
 		if err != nil {
-			content = fmt.Sprintf("Error: %s", err.Error())
+			content = err.Error()
 			errJSON, _ := json.Marshal(map[string]string{"error": err.Error()})
 			errJSON, _ := json.Marshal(map[string]string{"error": err.Error()})
 			jsonResponse = string(errJSON)
 			jsonResponse = string(errJSON)
 		} else {
 		} else {
@@ -126,13 +119,10 @@ func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
 			}
 			}
 			// Trim history if too long (UTF-8 safe)
 			// Trim history if too long (UTF-8 safe)
 			if len(content) > 2048 {
 			if len(content) > 2048 {
-				// UTF-8 continuation bytes start with 10xxxxxx (0x80-0xBF)
-				// Find a character boundary to avoid splitting multi-byte chars
-				for i := len(content) - 2048; i < len(content)-2040; i++ {
-					if content[i]&0xC0 != 0x80 { // Not a continuation byte
-						content = content[i:]
-						break
-					}
+				// Keep roughly last 600 characters (UTF-8 safe)
+				runes := []rune(content)
+				if len(runes) > 600 {
+					content = string(runes[len(runes)-600:])
 				}
 				}
 			}
 			}
 		}
 		}
@@ -141,26 +131,33 @@ func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 
 
 	accept := r.Header.Get("Accept")
 	accept := r.Header.Get("Accept")
+	wantsJSON := strings.Contains(accept, "application/json")
+	wantsHTML := strings.Contains(accept, "text/html")
+	wantsStream := strings.Contains(accept, "text/event-stream")
 
 
 	// Stream for curl when requested
 	// Stream for curl when requested
-	if strings.Contains(accept, "text/event-stream") && query != "" {
+	if wantsStream && query != "" {
 		w.Header().Set("Content-Type", "text/event-stream")
 		w.Header().Set("Content-Type", "text/event-stream")
 		w.Header().Set("Cache-Control", "no-cache")
 		w.Header().Set("Cache-Control", "no-cache")
 		w.Header().Set("Connection", "keep-alive")
 		w.Header().Set("Connection", "keep-alive")
-
+		
 		flusher, ok := w.(http.Flusher)
 		flusher, ok := w.(http.Flusher)
 		if !ok {
 		if !ok {
 			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
 			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
 			return
 			return
 		}
 		}
 
 
-		stream, err := getLLMResponseStream(ctx, prompt)
-		if err != nil {
-			fmt.Fprintf(w, "data: Error: %s\n\n", err.Error())
-			return
-		}
+		// Stream response
+		ch := make(chan string)
+		go func() {
+			if _, err := LLM(prompt, ch); err != nil {
+				// Send error as SSE event
+				fmt.Fprintf(w, "data: Error: %s\n\n", err.Error())
+				flusher.Flush()
+			}
+		}()
 
 
-		for chunk := range stream {
+		for chunk := range ch {
 			fmt.Fprintf(w, "data: %s\n\n", chunk)
 			fmt.Fprintf(w, "data: %s\n\n", chunk)
 			flusher.Flush()
 			flusher.Flush()
 		}
 		}
@@ -169,10 +166,10 @@ func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
 	}
 	}
 
 
 	// Return JSON for API requests, HTML for browsers, plain text for curl
 	// Return JSON for API requests, HTML for browsers, plain text for curl
-	if strings.Contains(accept, "application/json") && jsonResponse != "" {
+	if wantsJSON && jsonResponse != "" {
 		w.Header().Set("Content-Type", "application/json; charset=utf-8")
 		w.Header().Set("Content-Type", "application/json; charset=utf-8")
 		fmt.Fprint(w, jsonResponse)
 		fmt.Fprint(w, jsonResponse)
-	} else if strings.Contains(accept, "text/html") {
+	} else if wantsHTML {
 		w.Header().Set("Content-Type", "text/html; charset=utf-8")
 		w.Header().Set("Content-Type", "text/html; charset=utf-8")
 		fmt.Fprintf(w, minimalHTML, html.EscapeString(content), html.EscapeString(content))
 		fmt.Fprintf(w, minimalHTML, html.EscapeString(content), html.EscapeString(content))
 	} else {
 	} else {
@@ -181,3 +178,119 @@ func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
 		fmt.Fprint(w, content)
 		fmt.Fprint(w, content)
 	}
 	}
 }
 }
+
+type ChatRequest struct {
+	Model    string    `json:"model"`
+	Messages []Message `json:"messages"`
+	Stream   bool      `json:"stream,omitempty"`
+}
+
+type Message struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+}
+
+type ChatResponse struct {
+	ID      string   `json:"id"`
+	Object  string   `json:"object"`
+	Created int64    `json:"created"`
+	Model   string   `json:"model"`
+	Choices []Choice `json:"choices"`
+}
+
+type Choice struct {
+	Index   int     `json:"index"`
+	Message Message `json:"message"`
+}
+
+func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
+	if !rateLimitAllow(r.RemoteAddr) {
+		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
+		return
+	}
+	
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		return
+	}
+
+	var req ChatRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		http.Error(w, "Invalid JSON", http.StatusBadRequest)
+		return
+	}
+
+	// Convert messages to format for LLM
+	messages := make([]map[string]string, len(req.Messages))
+	for i, msg := range req.Messages {
+		messages[i] = map[string]string{
+			"role":    msg.Role,
+			"content": msg.Content,
+		}
+	}
+	
+	
+	if req.Stream {
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.Header().Set("Connection", "keep-alive")
+		
+		flusher, ok := w.(http.Flusher)
+		if !ok {
+			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
+			return
+		}
+		
+		// Stream response
+		ch := make(chan string)
+		go LLM(messages, ch)
+		
+		for chunk := range ch {
+			resp := map[string]interface{}{
+				"id": fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
+				"object": "chat.completion.chunk",
+				"created": time.Now().Unix(),
+				"model": req.Model,
+				"choices": []map[string]interface{}{{
+					"index": 0,
+					"delta": map[string]string{"content": chunk},
+				}},
+			}
+			data, err := json.Marshal(resp)
+			if err != nil {
+				fmt.Fprintf(w, "data: Failed to marshal response\n\n")
+				return
+			}
+			fmt.Fprintf(w, "data: %s\n\n", data)
+			flusher.Flush()
+		}
+		fmt.Fprintf(w, "data: [DONE]\n\n")
+		
+	} else {
+		response, err := LLM(messages, nil)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+			return
+		}
+
+		chatResp := ChatResponse{
+			ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
+			Object:  "chat.completion",
+			Created: time.Now().Unix(),
+			Model:   req.Model,
+			Choices: []Choice{{
+				Index: 0,
+				Message: Message{
+					Role:    "assistant",
+					Content: response,
+				},
+			}},
+		}
+
+		w.Header().Set("Content-Type", "application/json")
+		json.NewEncoder(w).Encode(chatResp)
+	}
+}
+
+
+

+ 138 - 0
llm.go.example

@@ -0,0 +1,138 @@
+package main
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+)
+
+// Configuration - Replace with your API credentials
+const (
+	// Option 1: OpenAI API
+	apiKey    = "YOUR_API_KEY_HERE"
+	apiURL    = "https://api.openai.com/v1/chat/completions"
+	modelName = "gpt-3.5-turbo" // or gpt-4, etc.
+
+	// Option 2: Anthropic Claude API (uncomment to use)
+	// apiKey     = "YOUR_API_KEY_HERE"
+	// apiURL     = "https://api.anthropic.com/v1/messages"
+	// modelName  = "claude-3-haiku" // or claude-3-opus, etc.
+
+	// Option 3: Local LLM (uncomment to use)
+	// apiKey     = "" // No API key needed for local models
+	// apiURL     = "http://localhost:11434/api/chat" // Ollama example
+	// modelName  = "llama2" // or mixtral, phi, etc.
+)
+
+// LLM calls the language model. If stream is nil, returns complete response via return value.
+// If stream is provided, streams response chunks to channel and returns empty string.
+// Input can be a string (wrapped as user message) or []map[string]string for full message history.
+func LLM(input interface{}, stream chan<- string) (string, error) {
+	// Build messages array
+	var messages []map[string]string
+	switch v := input.(type) {
+	case string:
+		messages = []map[string]string{
+			{"role": "user", "content": v},
+		}
+	case []map[string]string:
+		messages = v
+	default:
+		return "", fmt.Errorf("invalid input type")
+	}
+	
+	// Build request
+	requestBody := map[string]interface{}{
+		"model":       modelName,
+		"messages":    messages,
+		"temperature": 0.7,
+		"max_tokens":  500,
+	}
+	
+	if stream != nil {
+		requestBody["stream"] = true
+		defer close(stream)
+	}
+
+	jsonBody, err := json.Marshal(requestBody)
+	if err != nil {
+		return "", err
+	}
+
+	req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonBody))
+	if err != nil {
+		return "", err
+	}
+
+	req.Header.Set("Content-Type", "application/json")
+	if apiKey != "" {
+		req.Header.Set("Authorization", "Bearer "+apiKey)
+	}
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
+	}
+
+	// Handle streaming response
+	if stream != nil {
+		scanner := bufio.NewScanner(resp.Body)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if strings.HasPrefix(line, "data: ") {
+				data := strings.TrimPrefix(line, "data: ")
+				if data == "[DONE]" {
+					return "", nil
+				}
+				
+				var chunk map[string]interface{}
+				if err := json.Unmarshal([]byte(data), &chunk); err == nil {
+					if choices, ok := chunk["choices"].([]interface{}); ok && len(choices) > 0 {
+						if choice, ok := choices[0].(map[string]interface{}); ok {
+							if delta, ok := choice["delta"].(map[string]interface{}); ok {
+								if content, ok := delta["content"].(string); ok {
+									stream <- content
+								}
+							}
+						}
+					}
+				}
+			}
+		}
+		return "", scanner.Err()
+	}
+
+	// Handle non-streaming response
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", err
+	}
+
+	var response map[string]interface{}
+	if err := json.Unmarshal(body, &response); err != nil {
+		return "", err
+	}
+
+	if choices, ok := response["choices"].([]interface{}); ok && len(choices) > 0 {
+		if choice, ok := choices[0].(map[string]interface{}); ok {
+			if message, ok := choice["message"].(map[string]interface{}); ok {
+				if content, ok := message["content"].(string); ok {
+					return content, nil
+				}
+			}
+		}
+	}
+
+	return "", fmt.Errorf("unexpected response format")
+}

+ 0 - 147
openai.go

@@ -1,147 +0,0 @@
-package main
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"net/http"
-	"strings"
-	"time"
-)
-
-type OpenAIServer struct {
-	port int
-}
-
-func NewOpenAIServer(port int) *OpenAIServer {
-	return &OpenAIServer{port: port}
-}
-
-func (s *OpenAIServer) Start() error {
-	mux := http.NewServeMux()
-	mux.HandleFunc("/v1/chat/completions", s.handleChatCompletions)
-	
-	addr := fmt.Sprintf(":%d", s.port)
-	fmt.Printf("OpenAI API server listening on %s\n", addr)
-	return http.ListenAndServe(addr, mux)
-}
-
-type ChatRequest struct {
-	Model    string    `json:"model"`
-	Messages []Message `json:"messages"`
-	Stream   bool      `json:"stream,omitempty"`
-}
-
-type Message struct {
-	Role    string `json:"role"`
-	Content string `json:"content"`
-}
-
-type ChatResponse struct {
-	ID      string   `json:"id"`
-	Object  string   `json:"object"`
-	Created int64    `json:"created"`
-	Model   string   `json:"model"`
-	Choices []Choice `json:"choices"`
-}
-
-type Choice struct {
-	Index   int     `json:"index"`
-	Message Message `json:"message"`
-}
-
-
-func (s *OpenAIServer) handleChatCompletions(w http.ResponseWriter, r *http.Request) {
-	if r.Method != "POST" {
-		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
-		return
-	}
-
-	var req ChatRequest
-	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
-		http.Error(w, "Invalid JSON", http.StatusBadRequest)
-		return
-	}
-
-	// Convert messages to single prompt
-	prompt := buildPrompt(req.Messages)
-	
-	// Call our chat function
-	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
-	defer cancel()
-	
-	if req.Stream {
-		// Streaming response
-		w.Header().Set("Content-Type", "text/event-stream")
-		w.Header().Set("Cache-Control", "no-cache")
-		w.Header().Set("Connection", "keep-alive")
-		
-		flusher, ok := w.(http.Flusher)
-		if !ok {
-			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
-			return
-		}
-		
-		stream, err := getLLMResponseStream(ctx, prompt)
-		if err != nil {
-			fmt.Fprintf(w, "data: {\"error\": \"%s\"}\n\n", err.Error())
-			return
-		}
-		
-		for chunk := range stream {
-			resp := map[string]interface{}{
-				"id": fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
-				"object": "chat.completion.chunk",
-				"created": time.Now().Unix(),
-				"model": req.Model,
-				"choices": []map[string]interface{}{{
-					"index": 0,
-					"delta": map[string]string{"content": chunk},
-				}},
-			}
-			data, err := json.Marshal(resp)
-			if err != nil {
-				fmt.Fprintf(w, "data: error marshaling response\n\n")
-				return
-			}
-			fmt.Fprintf(w, "data: %s\n\n", data)
-			flusher.Flush()
-		}
-		fmt.Fprintf(w, "data: [DONE]\n\n")
-		
-	} else {
-		// Non-streaming response
-		response, err := getLLMResponse(ctx, prompt)
-		if err != nil {
-			http.Error(w, fmt.Sprintf("Chat error: %v", err), http.StatusInternalServerError)
-			return
-		}
-
-		// Return OpenAI-compatible response
-		chatResp := ChatResponse{
-			ID:      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
-			Object:  "chat.completion",
-			Created: time.Now().Unix(),
-			Model:   req.Model,
-			Choices: []Choice{{
-				Index: 0,
-				Message: Message{
-					Role:    "assistant",
-					Content: response,
-				},
-			}},
-		}
-
-		w.Header().Set("Content-Type", "application/json")
-		json.NewEncoder(w).Encode(chatResp)
-	}
-}
-
-func buildPrompt(messages []Message) string {
-	// Simple: just concatenate messages
-	var parts []string
-	for _, msg := range messages {
-		parts = append(parts, msg.Content)
-	}
-	return strings.Join(parts, "\n")
-}

+ 23 - 57
ssh.go

@@ -1,28 +1,16 @@
 package main
 package main
 
 
 import (
 import (
-	"context"
 	"crypto/rand"
 	"crypto/rand"
 	"crypto/rsa"
 	"crypto/rsa"
-	"crypto/x509"
-	"encoding/pem"
 	"fmt"
 	"fmt"
 	"net"
 	"net"
-	"os"
 	"strings"
 	"strings"
 
 
 	"golang.org/x/crypto/ssh"
 	"golang.org/x/crypto/ssh"
 )
 )
 
 
-type SSHServer struct {
-	port int
-}
-
-func NewSSHServer(port int) *SSHServer {
-	return &SSHServer{port: port}
-}
-
-func (s *SSHServer) Start() error {
+func StartSSHServer(port int) error {
 	// SSH server configuration
 	// SSH server configuration
 	config := &ssh.ServerConfig{
 	config := &ssh.ServerConfig{
 		NoClientAuth: true, // Anonymous access
 		NoClientAuth: true, // Anonymous access
@@ -36,13 +24,13 @@ func (s *SSHServer) Start() error {
 	config.AddHostKey(privateKey)
 	config.AddHostKey(privateKey)
 
 
 	// Listen for connections
 	// Listen for connections
-	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
+	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 	defer listener.Close()
 	defer listener.Close()
 
 
-	fmt.Printf("SSH server listening on :%d\n", s.port)
+	fmt.Printf("SSH server listening on :%d\n", port)
 
 
 	// Simple connection limiting
 	// Simple connection limiting
 	sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections
 	sem := make(chan struct{}, 100) // Max 100 concurrent SSH connections
@@ -58,7 +46,7 @@ func (s *SSHServer) Start() error {
 		case sem <- struct{}{}:
 		case sem <- struct{}{}:
 			go func() {
 			go func() {
 				defer func() { <-sem }()
 				defer func() { <-sem }()
-				s.handleConnection(conn, config)
+				handleConnection(conn, config)
 			}()
 			}()
 		default:
 		default:
 			// Too many connections
 			// Too many connections
@@ -67,12 +55,12 @@ func (s *SSHServer) Start() error {
 	}
 	}
 }
 }
 
 
-func (s *SSHServer) handleConnection(netConn net.Conn, config *ssh.ServerConfig) {
+func handleConnection(netConn net.Conn, config *ssh.ServerConfig) {
 	defer netConn.Close()
 	defer netConn.Close()
 
 
 	// Rate limiting
 	// Rate limiting
-	if !rateLimiter.Allow(netConn.RemoteAddr().String()) {
-		netConn.Write([]byte("Rate limit exceeded. Please try again later.\r\n"))
+	if !rateLimitAllow(netConn.RemoteAddr().String()) {
+		netConn.Write([]byte("Rate limit exceeded\r\n"))
 		return
 		return
 	}
 	}
 
 
@@ -99,11 +87,11 @@ func (s *SSHServer) handleConnection(netConn net.Conn, config *ssh.ServerConfig)
 			continue
 			continue
 		}
 		}
 
 
-		go s.handleSession(channel, requests)
+		go handleSession(channel, requests)
 	}
 	}
 }
 }
 
 
-func (s *SSHServer) handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
+func handleSession(channel ssh.Channel, requests <-chan *ssh.Request) {
 	defer channel.Close()
 	defer channel.Close()
 
 
 	// Handle session requests
 	// Handle session requests
@@ -150,34 +138,29 @@ func (s *SSHServer) handleSession(channel ssh.Channel, requests <-chan *ssh.Requ
 					}
 					}
 					
 					
 					// Get LLM response with streaming
 					// Get LLM response with streaming
-					ctx := context.Background()
-					stream, err := getLLMResponseStream(ctx, query)
-					if err != nil {
-						fmt.Fprintf(channel, "Error: %v\r\n", err)
-						fmt.Fprintf(channel, "> ")
-						continue
-					}
+					ch := make(chan string)
+					go func() {
+						if _, err := LLM(query, ch); err != nil {
+							fmt.Fprintf(channel, "Error: %s\r\n", err.Error())
+						}
+					}()
 					
 					
 					// Stream response as it arrives
 					// Stream response as it arrives
-					for chunk := range stream {
+					for chunk := range ch {
 						fmt.Fprint(channel, chunk)
 						fmt.Fprint(channel, chunk)
-						if f, ok := channel.(interface{ Flush() }); ok {
-							f.Flush()
-						}
 					}
 					}
 					
 					
 					fmt.Fprintf(channel, "\r\n> ")
 					fmt.Fprintf(channel, "\r\n> ")
 				}
 				}
 			} else if ch == '\b' || ch == 127 { // Backspace or Delete
 			} else if ch == '\b' || ch == 127 { // Backspace or Delete
 				if input.Len() > 0 {
 				if input.Len() > 0 {
-					// Remove last character from buffer
+					// Remove last rune (UTF-8 safe)
 					str := input.String()
 					str := input.String()
+					runes := []rune(str)
 					input.Reset()
 					input.Reset()
-					if len(str) > 0 {
-						input.WriteString(str[:len(str)-1])
-						// Move cursor back, overwrite with space, move back again
-						fmt.Fprintf(channel, "\b \b")
-					}
+					input.WriteString(string(runes[:len(runes)-1]))
+					// Move cursor back, overwrite with space, move back again
+					fmt.Fprintf(channel, "\b \b")
 				}
 				}
 			} else {
 			} else {
 				// Echo the character back to the user
 				// Echo the character back to the user
@@ -188,31 +171,14 @@ func (s *SSHServer) handleSession(channel ssh.Channel, requests <-chan *ssh.Requ
 	}
 	}
 }
 }
 
 
-// getOrCreateHostKey loads existing key or generates new one
+// getOrCreateHostKey generates a new ephemeral host key
 func getOrCreateHostKey() (ssh.Signer, error) {
 func getOrCreateHostKey() (ssh.Signer, error) {
-	keyPath := "ssh_host_key"
-	
-	// Try to load existing key
-	if keyData, err := os.ReadFile(keyPath); err == nil {
-		return ssh.ParsePrivateKey(keyData)
-	}
-
-	// Generate new ephemeral key (more private but less convenient)
+	// Generate new ephemeral key each time
 	// Users will see "host key changed" warnings on each restart
 	// Users will see "host key changed" warnings on each restart
 	key, err := rsa.GenerateKey(rand.Reader, 2048)
 	key, err := rsa.GenerateKey(rand.Reader, 2048)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	// Optionally save for convenience (comment out for max privacy)
-	keyData := pem.EncodeToMemory(&pem.Block{
-		Type:  "RSA PRIVATE KEY",
-		Bytes: x509.MarshalPKCS1PrivateKey(key),
-	})
-	
-	if err := os.WriteFile(keyPath, keyData, 0600); err != nil {
-		// Couldn't save host key - continue anyway
-	}
-
 	return ssh.NewSignerFromKey(key)
 	return ssh.NewSignerFromKey(key)
 }
 }

+ 26 - 104
util.go

@@ -4,115 +4,37 @@ import (
 	"net"
 	"net"
 	"sync"
 	"sync"
 	"time"
 	"time"
-)
-
-// Simple in-memory rate limiter
-// To disable: Remove NewRateLimiter calls from protocol files
-type RateLimiter struct {
-	requests map[string][]time.Time
-	mu       sync.Mutex
-	limit    int           // requests per window
-	window   time.Duration // time window
-	stopCh   chan struct{} // for cleanup goroutine
-}
 
 
-func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
-	r := &RateLimiter{
-		requests: make(map[string][]time.Time),
-		limit:    limit,
-		window:   window,
-		stopCh:   make(chan struct{}),
-	}
-	
-	// Start cleanup goroutine
-	go r.cleanup()
-	
-	return r
-}
+	"golang.org/x/time/rate"
+)
 
 
-func (r *RateLimiter) Allow(addr string) bool {
-	r.mu.Lock()
-	defer r.mu.Unlock()
+// Rate limiters per IP with automatic cleanup
+var (
+	limiters sync.Map
+	lastClean = time.Now()
+)
 
 
-	// Extract IP without port (addr might be "1.2.3.4:5678" or just "1.2.3.4")
-	ip, _, _ := net.SplitHostPort(addr)
-	if ip == "" {
-		ip = addr // addr was already just an IP
+func rateLimitAllow(addr string) bool {
+	// Extract just the IP
+	ip := addr
+	if host, _, err := net.SplitHostPort(addr); err == nil {
+		ip = host
 	}
 	}
-
-	now := time.Now()
-	cutoff := now.Add(-r.window)
-
-	// Get or create request list
-	requests := r.requests[ip]
 	
 	
-	// Remove old requests
-	valid := []time.Time{}
-	for _, t := range requests {
-		if t.After(cutoff) {
-			valid = append(valid, t)
-		}
-	}
-
-	// Check limit
-	if len(valid) >= r.limit {
-		return false
+	// Clean old entries every hour
+	if time.Since(lastClean) > time.Hour {
+		lastClean = time.Now()
+		limiters.Range(func(key, value interface{}) bool {
+			if l, ok := value.(*rate.Limiter); ok && l.Tokens() >= 10 {
+				limiters.Delete(key)
+			}
+			return true
+		})
 	}
 	}
-
-	// Add new request
-	valid = append(valid, now)
-	r.requests[ip] = valid
 	
 	
-	return true
-}
-
-// Periodic cleanup to prevent unbounded memory growth
-func (r *RateLimiter) cleanup() {
-	ticker := time.NewTicker(r.window)
-	defer ticker.Stop()
+	// Get or create limiter for this IP
+	limiterInterface, _ := limiters.LoadOrStore(ip, rate.NewLimiter(100.0/60, 10))
+	limiter := limiterInterface.(*rate.Limiter)
 	
 	
-	for {
-		select {
-		case <-ticker.C:
-			r.mu.Lock()
-			now := time.Now()
-			cutoff := now.Add(-r.window)
-			
-			// Remove IPs with no recent requests
-			for ip, requests := range r.requests {
-				valid := []time.Time{}
-				for _, t := range requests {
-					if t.After(cutoff) {
-						valid = append(valid, t)
-					}
-				}
-				
-				if len(valid) == 0 {
-					delete(r.requests, ip)
-				} else {
-					r.requests[ip] = valid
-				}
-			}
-			r.mu.Unlock()
-			
-		case <-r.stopCh:
-			return
-		}
-	}
-}
-
-// Stop the rate limiter cleanup
-func (r *RateLimiter) Stop() {
-	select {
-	case <-r.stopCh:
-		// Already closed
-	default:
-		close(r.stopCh)
-	}
-}
-
-// Add other shared utilities here as needed
-// Each should be self-contained and optional
-
-// Global rate limiter instance
-var rateLimiter = NewRateLimiter(100, time.Minute) // 100 requests per minute per IP
+	return limiter.Allow()
+}