http.go 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. package main
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "html"
  7. "io"
  8. "net/http"
  9. "strings"
  10. "time"
  11. )
  12. const minimalHTML = `<!DOCTYPE html>
  13. <html>
  14. <head>
  15. <title>ch.at</title>
  16. <style>
  17. body { text-align: center; margin: 40px; }
  18. pre { text-align: left; max-width: 600px; margin: 20px auto; padding: 20px;
  19. white-space: pre-wrap; word-wrap: break-word; }
  20. input[type="text"] { width: 300px; }
  21. </style>
  22. </head>
  23. <body>
  24. <h1>ch.at</h1>
  25. <p><i>pronounced "ch-dot-at"</i></p>
  26. <pre>%s</pre>
  27. <form method="POST" action="/">
  28. <input type="text" name="q" placeholder="Type your message..." autofocus>
  29. <textarea name="h" style="display:none">%s</textarea>
  30. <input type="submit" value="Send">
  31. </form>
  32. <p><a href="/">Clear History</a> • <a href="https://github.com/ch-at/ch.at#readme">About</a></p>
  33. </body>
  34. </html>`
  // HTTPServer is the HTTP front end for ch.at. Its only setting is port, the
// TCP port Start listens on (StartTLS takes its port as an argument instead).
type HTTPServer struct{ port int }
  38. func NewHTTPServer(port int) *HTTPServer {
  39. return &HTTPServer{port: port}
  40. }
  41. func (s *HTTPServer) Start() error {
  42. http.HandleFunc("/", s.handleRoot)
  43. addr := fmt.Sprintf(":%d", s.port)
  44. fmt.Printf("HTTP server listening on %s\n", addr)
  45. return http.ListenAndServe(addr, nil)
  46. }
  47. func (s *HTTPServer) StartTLS(port int, certFile, keyFile string) error {
  48. addr := fmt.Sprintf(":%d", port)
  49. fmt.Printf("HTTPS server listening on %s\n", addr)
  50. return http.ListenAndServeTLS(addr, certFile, keyFile, nil)
  51. }
  52. func (s *HTTPServer) handleRoot(w http.ResponseWriter, r *http.Request) {
  53. if !rateLimiter.Allow(r.RemoteAddr) {
  54. http.Error(w, "Rate limit exceeded. Please try again later.", http.StatusTooManyRequests)
  55. return
  56. }
  57. var query, history, prompt string
  58. content := ""
  59. jsonResponse := ""
  60. if r.Method == "POST" {
  61. if err := r.ParseForm(); err != nil {
  62. http.Error(w, "Failed to parse form", http.StatusBadRequest)
  63. return
  64. }
  65. query = r.FormValue("q")
  66. history = r.FormValue("h")
  67. // Limit history size to prevent abuse
  68. if len(history) > 2048 {
  69. history = history[len(history)-2048:]
  70. }
  71. // If no form fields, treat body as raw query (for curl)
  72. if query == "" {
  73. body, err := io.ReadAll(io.LimitReader(r.Body, 4096)) // Limit body size
  74. if err != nil {
  75. http.Error(w, "Failed to read request body", http.StatusBadRequest)
  76. return
  77. }
  78. query = string(body)
  79. }
  80. } else {
  81. // GET request - no history
  82. query = r.URL.Query().Get("q")
  83. }
  84. ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
  85. defer cancel()
  86. if query != "" {
  87. // Build prompt with history
  88. prompt = query
  89. if history != "" {
  90. prompt = history + "Q: " + query
  91. }
  92. response, err := getLLMResponse(ctx, prompt)
  93. if err != nil {
  94. content = fmt.Sprintf("Error: %s", err.Error())
  95. errJSON, _ := json.Marshal(map[string]string{"error": err.Error()})
  96. jsonResponse = string(errJSON)
  97. } else {
  98. // Store JSON response
  99. respJSON, _ := json.Marshal(map[string]string{
  100. "question": query,
  101. "answer": response,
  102. })
  103. jsonResponse = string(respJSON)
  104. // Append to history
  105. newExchange := fmt.Sprintf("Q: %s\nA: %s\n\n", query, response)
  106. if history != "" {
  107. content = history + newExchange
  108. } else {
  109. content = newExchange
  110. }
  111. // Trim history if too long (UTF-8 safe)
  112. if len(content) > 2048 {
  113. // UTF-8 continuation bytes start with 10xxxxxx (0x80-0xBF)
  114. // Find a character boundary to avoid splitting multi-byte chars
  115. for i := len(content) - 2048; i < len(content)-2040; i++ {
  116. if content[i]&0xC0 != 0x80 { // Not a continuation byte
  117. content = content[i:]
  118. break
  119. }
  120. }
  121. }
  122. }
  123. } else if history != "" {
  124. content = history
  125. }
  126. accept := r.Header.Get("Accept")
  127. // Stream for curl when requested
  128. if strings.Contains(accept, "text/event-stream") && query != "" {
  129. w.Header().Set("Content-Type", "text/event-stream")
  130. w.Header().Set("Cache-Control", "no-cache")
  131. w.Header().Set("Connection", "keep-alive")
  132. flusher, ok := w.(http.Flusher)
  133. if !ok {
  134. http.Error(w, "Streaming not supported", http.StatusInternalServerError)
  135. return
  136. }
  137. stream, err := getLLMResponseStream(ctx, prompt)
  138. if err != nil {
  139. fmt.Fprintf(w, "data: Error: %s\n\n", err.Error())
  140. return
  141. }
  142. for chunk := range stream {
  143. fmt.Fprintf(w, "data: %s\n\n", chunk)
  144. flusher.Flush()
  145. }
  146. fmt.Fprintf(w, "data: [DONE]\n\n")
  147. return
  148. }
  149. // Return JSON for API requests, HTML for browsers, plain text for curl
  150. if strings.Contains(accept, "application/json") && jsonResponse != "" {
  151. w.Header().Set("Content-Type", "application/json; charset=utf-8")
  152. fmt.Fprint(w, jsonResponse)
  153. } else if strings.Contains(accept, "text/html") {
  154. w.Header().Set("Content-Type", "text/html; charset=utf-8")
  155. fmt.Fprintf(w, minimalHTML, html.EscapeString(content), html.EscapeString(content))
  156. } else {
  157. // Default to plain text for curl and other tools
  158. w.Header().Set("Content-Type", "text/plain; charset=utf-8")
  159. fmt.Fprint(w, content)
  160. }
  161. }