Просмотр исходного кода

Fix memory leaks and DNS panic with O(1) rate limiting

ajasibley 1 месяц назад
Родитель
Сommit
9246288582
3 измененных файлов с 43 добавлено и 36 удалено
  1. 1 7
      dns.go
  2. 16 12
      http.go
  3. 26 17
      util.go

+ 1 - 7
dns.go

@@ -49,13 +49,7 @@ func handleDNS(w dns.ResponseWriter, r *dns.Msg) {
 		done := make(chan bool)
 
 		go func() {
-			if _, err := LLM(dnsPrompt, ch); err != nil {
-				select {
-				case ch <- "Error: " + err.Error():
-				case <-done:
-				}
-			}
-			// Don't close ch here - LLM function already does it with defer
+			LLM(dnsPrompt, ch)
 		}()
 
 		var response strings.Builder

+ 16 - 12
http.go

@@ -164,23 +164,23 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 			fmt.Fprintf(w, "<div class=\"q\">%s</div>\n<div class=\"a\">", html.EscapeString(query))
 			flusher.Flush()
 
-			ch := make(chan string)
+			ch := make(chan string, 10)
 			go func() {
 				htmlPrompt := htmlPromptPrefix + prompt
 				LLM(htmlPrompt, ch)
 			}()
 
-			response := ""
+			var response strings.Builder
 			for chunk := range ch {
 				if _, err := fmt.Fprint(w, chunk); err != nil {
 					return
 				}
-				response += chunk
+				response.WriteString(chunk)
 				flusher.Flush()
 			}
 			fmt.Fprint(w, "</div>\n")
 
-			finalHistory := history + fmt.Sprintf("Q: %s\nA: %s\n\n", query, response)
+			finalHistory := history + fmt.Sprintf("Q: %s\nA: %s\n\n", query, response.String())
 			fmt.Fprintf(w, htmlFooterTemplate, html.EscapeString(finalHistory))
 			return
 		}
@@ -196,15 +196,15 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 			fmt.Fprintf(w, "Q: %s\nA: ", query)
 			flusher.Flush()
 
-			ch := make(chan string)
+			ch := make(chan string, 10)
 			go func() {
 				LLM(prompt, ch)
 			}()
 
-			response := ""
 			for chunk := range ch {
-				fmt.Fprint(w, chunk)
-				response += chunk
+				if _, err := fmt.Fprint(w, chunk); err != nil {
+					return
+				}
 				flusher.Flush()
 			}
 			fmt.Fprint(w, "\n")
@@ -260,13 +260,15 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 
-		ch := make(chan string)
+		ch := make(chan string, 10)
 		go func() {
 			LLM(prompt, ch)
 		}()
 
 		for chunk := range ch {
-			fmt.Fprintf(w, "data: %s\n\n", chunk)
+			if _, err := fmt.Fprintf(w, "data: %s\n\n", chunk); err != nil {
+				return
+			}
 			flusher.Flush()
 		}
 		fmt.Fprintf(w, "data: [DONE]\n\n")
@@ -368,7 +370,7 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 			return
 		}
 
-		ch := make(chan string)
+		ch := make(chan string, 10)
 		go LLM(messages, ch)
 
 		for chunk := range ch {
@@ -387,7 +389,9 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 				fmt.Fprintf(w, "data: Failed to marshal response\n\n")
 				return
 			}
-			fmt.Fprintf(w, "data: %s\n\n", data)
+			if _, err := fmt.Fprintf(w, "data: %s\n\n", data); err != nil {
+				return
+			}
 			flusher.Flush()
 		}
 		fmt.Fprintf(w, "data: [DONE]\n\n")

+ 26 - 17
util.go

@@ -3,38 +3,47 @@ package main
 import (
 	"net"
 	"sync"
-	"time"
+	"sync/atomic"
 
 	"golang.org/x/time/rate"
 )
 
-// Rate limiters per IP with automatic cleanup
+const maxEntries = 10000 // Rotate when current map reaches this size (~2.5MB)
+
 var (
-	limiters  sync.Map
-	lastClean = time.Now()
+	current      = &sync.Map{}
+	previous     = &sync.Map{}
+	currentCount int64
 )
 
 func rateLimitAllow(addr string) bool {
-	// Extract just the IP
 	ip := addr
 	if host, _, err := net.SplitHostPort(addr); err == nil {
 		ip = host
 	}
 
-	// Clean old entries every hour
-	if time.Since(lastClean) > time.Hour {
-		lastClean = time.Now()
-		limiters.Range(func(key, value interface{}) bool {
-			if l, ok := value.(*rate.Limiter); ok && l.Tokens() >= 10 {
-				limiters.Delete(key)
-			}
-			return true
-		})
+	if atomic.LoadInt64(&currentCount) >= maxEntries {
+		rotate()
+	}
+
+	if val, ok := current.Load(ip); ok {
+		return val.(*rate.Limiter).Allow()
 	}
 
-	// Get or create limiter for this IP
-	limiterInterface, _ := limiters.LoadOrStore(ip, rate.NewLimiter(100.0/60, 10))
-	limiter := limiterInterface.(*rate.Limiter)
+	if val, ok := previous.Load(ip); ok {
+		current.Store(ip, val)
+		atomic.AddInt64(&currentCount, 1)
+		return val.(*rate.Limiter).Allow()
+	}
 
+	limiter := rate.NewLimiter(100.0/60, 10)
+	current.Store(ip, limiter)
+	atomic.AddInt64(&currentCount, 1)
 	return limiter.Allow()
 }
+
+func rotate() {
+	previous = current
+	current = &sync.Map{}
+	atomic.StoreInt64(&currentCount, 0)
+}