Add dynamic streaming support

ajasibley 5 months ago
parent commit d83125f9a7
1 changed file with 87 additions and 45 deletions

http.go  +87 -45

@@ -10,7 +10,7 @@ import (
 	"time"
 )
 
-const minimalHTML = `<!DOCTYPE html>
+const htmlHeader = `<!DOCTYPE html>
 <html>
 <head>
     <title>ch.at</title>
@@ -24,11 +24,13 @@ const minimalHTML = `<!DOCTYPE html>
 <body>
     <h1>ch.at</h1>
     <p><i>pronounced "ch-dot-at"</i></p>
-    <pre>%s</pre>
+    <pre>`
+
+const htmlFooterTemplate = `</pre>
     <form method="POST" action="/">
         <input type="text" name="q" placeholder="Type your message..." autofocus>
-        <textarea name="h" style="display:none">%s</textarea>
         <input type="submit" value="Send">
+        <textarea name="h" style="display:none">%s</textarea>
     </form>
     <p><a href="/">Clear History</a> • <a href="https://github.com/Deep-ai-inc/ch.at#readme">About</a></p>
 </body>
@@ -65,14 +67,13 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 		query = r.FormValue("q")
 		history = r.FormValue("h")
 
-		// Limit history size to prevent abuse
-		if len(history) > 2048 {
-			history = history[len(history)-2048:]
+		// Limit history size to ensure compatibility
+		if len(history) > 65536 {
+			history = history[len(history)-65536:]
 		}
 
-		// If no form fields, treat body as raw query (for curl)
 		if query == "" {
-			body, err := io.ReadAll(io.LimitReader(r.Body, 4096)) // Limit body size
+			body, err := io.ReadAll(io.LimitReader(r.Body, 65536)) // Limit body size
 			if err != nil {
 				http.Error(w, "Failed to read request body", http.StatusBadRequest)
 				return
@@ -81,46 +82,100 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 		}
 	} else {
 		query = r.URL.Query().Get("q")
-		// Also support path-based queries like /what-is-go
+		// Support path-based queries like /what-is-go
 		if query == "" && r.URL.Path != "/" {
 			query = strings.ReplaceAll(strings.TrimPrefix(r.URL.Path, "/"), "-", " ")
 		}
 	}
 
+	accept := r.Header.Get("Accept")
+	wantsJSON := strings.Contains(accept, "application/json")
+	wantsHTML := strings.Contains(accept, "text/html")
+	wantsStream := strings.Contains(accept, "text/event-stream")
 
 	if query != "" {
-		// Build prompt with history
 		prompt = query
 		if history != "" {
 			prompt = history + "Q: " + query
 		}
 
+		if wantsHTML && r.Header.Get("Accept") != "application/json" {
+			w.Header().Set("Content-Type", "text/html; charset=utf-8")
+			w.Header().Set("Transfer-Encoding", "chunked")
+			w.Header().Set("X-Accel-Buffering", "no")
+			w.Header().Set("Cache-Control", "no-cache")
+			flusher := w.(http.Flusher)
+
+			displayHistory := history
+			headerSize := len(htmlHeader)
+			historySize := len(html.EscapeString(history))
+			querySize := len(html.EscapeString(query))
+			currentSize := headerSize + historySize + querySize + 10
+
+			// Browser streaming needs significant content - working version used 6KB
+			const minThreshold = 6144 // 6KB threshold (matching what worked before)
+			if currentSize < minThreshold {
+				// Each zero-width space is 3 bytes in UTF-8
+				paddingNeeded := (minThreshold - currentSize) / 3
+				if paddingNeeded > 0 {
+					padding := strings.Repeat("\u200B", paddingNeeded)
+					displayHistory = padding + history
+				}
+			}
+
+			fmt.Fprint(w, htmlHeader)
+			fmt.Fprintf(w, "%sQ: %s\nA: ", html.EscapeString(displayHistory), html.EscapeString(query))
+			flusher.Flush()
+
+			ch := make(chan string)
+			go func() {
+				if _, err := LLM(prompt, ch); err != nil {
+					ch <- err.Error()
+					close(ch)
+				}
+			}()
+
+			response := ""
+			for chunk := range ch {
+				if _, err := fmt.Fprint(w, html.EscapeString(chunk)); err != nil {
+					return
+				}
+				response += chunk
+				flusher.Flush()
+			}
+
+			finalHistory := history + fmt.Sprintf("Q: %s\nA: %s\n\n", query, response)
+			fmt.Fprintf(w, htmlFooterTemplate, html.EscapeString(finalHistory))
+			return
+		}
+
 		response, err := LLM(prompt, nil)
 		if err != nil {
 			content = err.Error()
 			errJSON, _ := json.Marshal(map[string]string{"error": err.Error()})
 			jsonResponse = string(errJSON)
 		} else {
-			// Store JSON response
 			respJSON, _ := json.Marshal(map[string]string{
 				"question": query,
 				"answer":   response,
 			})
 			jsonResponse = string(respJSON)
 
-			// Append to history
 			newExchange := fmt.Sprintf("Q: %s\nA: %s\n\n", query, response)
 			if history != "" {
 				content = history + newExchange
 			} else {
 				content = newExchange
 			}
-			// Trim history if too long (UTF-8 safe)
-			if len(content) > 2048 {
-				// Keep roughly last 600 characters (UTF-8 safe)
-				runes := []rune(content)
-				if len(runes) > 600 {
-					content = string(runes[len(runes)-600:])
+			if len(content) > 65536 {
+				newExchangeLen := len(newExchange)
+				if newExchangeLen > 65536 {
+					content = newExchange[:65536]
+				} else {
+					maxHistory := 65536 - newExchangeLen
+					if len(history) > maxHistory {
+						content = history[len(history)-maxHistory:] + newExchange
+					}
 				}
 			}
 		}
@@ -128,28 +183,20 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 		content = history
 	}
 
-	accept := r.Header.Get("Accept")
-	wantsJSON := strings.Contains(accept, "application/json")
-	wantsHTML := strings.Contains(accept, "text/html")
-	wantsStream := strings.Contains(accept, "text/event-stream")
-
-	// Stream for curl when requested
 	if wantsStream && query != "" {
 		w.Header().Set("Content-Type", "text/event-stream")
 		w.Header().Set("Cache-Control", "no-cache")
 		w.Header().Set("Connection", "keep-alive")
-		
+
 		flusher, ok := w.(http.Flusher)
 		if !ok {
 			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
 			return
 		}
 
-		// Stream response
 		ch := make(chan string)
 		go func() {
 			if _, err := LLM(prompt, ch); err != nil {
-				// Send error as SSE event
 				fmt.Fprintf(w, "data: Error: %s\n\n", err.Error())
 				flusher.Flush()
 			}
@@ -163,15 +210,15 @@ func handleRoot(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Return JSON for API requests, HTML for browsers, plain text for curl
 	if wantsJSON && jsonResponse != "" {
 		w.Header().Set("Content-Type", "application/json; charset=utf-8")
 		fmt.Fprint(w, jsonResponse)
-	} else if wantsHTML {
+	} else if wantsHTML && query == "" {
 		w.Header().Set("Content-Type", "text/html; charset=utf-8")
-		fmt.Fprintf(w, minimalHTML, html.EscapeString(content), html.EscapeString(content))
+		fmt.Fprint(w, htmlHeader)
+		fmt.Fprint(w, html.EscapeString(content))
+		fmt.Fprintf(w, htmlFooterTemplate, html.EscapeString(content))
 	} else {
-		// Default to plain text for curl and other tools
 		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
 		fmt.Fprint(w, content)
 	}
@@ -206,7 +253,7 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
 		return
 	}
-	
+
 	if r.Method != "POST" {
 		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
 		return
@@ -218,7 +265,6 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Convert messages to format for LLM
 	messages := make([]map[string]string, len(req.Messages))
 	for i, msg := range req.Messages {
 		messages[i] = map[string]string{
@@ -226,29 +272,27 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 			"content": msg.Content,
 		}
 	}
-	
-	
+
 	if req.Stream {
 		w.Header().Set("Content-Type", "text/event-stream")
 		w.Header().Set("Cache-Control", "no-cache")
 		w.Header().Set("Connection", "keep-alive")
-		
+
 		flusher, ok := w.(http.Flusher)
 		if !ok {
 			http.Error(w, "Streaming not supported", http.StatusInternalServerError)
 			return
 		}
-		
-		// Stream response
+
 		ch := make(chan string)
 		go LLM(messages, ch)
-		
+
 		for chunk := range ch {
 			resp := map[string]interface{}{
-				"id": fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
-				"object": "chat.completion.chunk",
+				"id":      fmt.Sprintf("chatcmpl-%d", time.Now().Unix()),
+				"object":  "chat.completion.chunk",
 				"created": time.Now().Unix(),
-				"model": req.Model,
+				"model":   req.Model,
 				"choices": []map[string]interface{}{{
 					"index": 0,
 					"delta": map[string]string{"content": chunk},
@@ -263,7 +307,7 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 			flusher.Flush()
 		}
 		fmt.Fprintf(w, "data: [DONE]\n\n")
-		
+
 	} else {
 		response, err := LLM(messages, nil)
 		if err != nil {
@@ -290,5 +334,3 @@ func handleChatCompletions(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-
-
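
For reference, a minimal client sketch that exercises the new Accept-header streaming path added above. The host and port (localhost:8080), the query text, and the assumption that successful chunks are framed as "data: ...\n\n" SSE lines (only the error path of that framing is visible in these hunks) are illustrative and not part of the commit.

// Minimal sketch of a streaming client for the handler above.
// Assumptions (not in the commit): the server listens on localhost:8080,
// and successful chunks use the same "data: ...\n\n" SSE framing as the error path.
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"net/url"
	"strings"
)

func main() {
	// "q" is the same form/query parameter handleRoot reads.
	q := url.Values{"q": {"what is go"}}
	req, err := http.NewRequest("GET", "http://localhost:8080/?"+q.Encode(), nil)
	if err != nil {
		panic(err)
	}
	// Asking for text/event-stream makes handleRoot take the streaming branch.
	req.Header.Set("Accept", "text/event-stream")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Print SSE payloads as they arrive; stop when the body is closed.
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "data: ") {
			fmt.Print(strings.TrimPrefix(line, "data: "))
		}
	}
	fmt.Println()
}

Sending Accept: text/html instead would exercise the chunked HTML path, where the handler pads the first flush with zero-width spaces up to the 6KB threshold so browsers start rendering the response progressively.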