server.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. package server
  2. import (
  3. "bytes"
  4. "embed"
  5. _ "embed" // required for go:embed
  6. "encoding/json"
  7. "fmt"
  8. "golang.org/x/time/rate"
  9. "heckel.io/ntfy/config"
  10. "io"
  11. "log"
  12. "net"
  13. "net/http"
  14. "regexp"
  15. "strings"
  16. "sync"
  17. "time"
  18. )
  19. // Server is the main server
  20. type Server struct {
  21. config *config.Config
  22. topics map[string]*topic
  23. visitors map[string]*visitor
  24. mu sync.Mutex
  25. }
  26. // visitor represents an API user, and its associated rate.Limiter used for rate limiting
  27. type visitor struct {
  28. limiter *rate.Limiter
  29. seen time.Time
  30. }
  // errHTTP is an error that carries the HTTP status to reply with. Handlers
  // return it for non-200 outcomes; anything else becomes a 500.
  type errHTTP struct {
  	Code   int    // HTTP status code, e.g. http.StatusNotFound
  	Status string // human-readable status text
  }

  // Error implements the error interface, yielding e.g. "http: Not Found".
  func (e errHTTP) Error() string {
  	return fmt.Sprint("http: ", e.Status)
  }
  const (
  	// messageLimit is the maximum number of bytes read from a publish request
  	// body (1 KiB); any remainder is not read.
  	messageLimit = 1 << 10

  	// visitorExpungeAfter is how long a visitor may go without a request before
  	// the periodic manager removes it from the visitors map.
  	visitorExpungeAfter = 30 * time.Minute
  )
  43. var (
  44. topicRegex = regexp.MustCompile(`^/[^/]+$`)
  45. jsonRegex = regexp.MustCompile(`^/[^/]+/json$`)
  46. sseRegex = regexp.MustCompile(`^/[^/]+/sse$`)
  47. rawRegex = regexp.MustCompile(`^/[^/]+/raw$`)
  48. staticRegex = regexp.MustCompile(`^/static/.+`)
  49. //go:embed "index.html"
  50. indexSource string
  51. //go:embed static
  52. webStaticFs embed.FS
  53. errHTTPNotFound = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)}
  54. errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)}
  55. )
  56. func New(conf *config.Config) *Server {
  57. return &Server{
  58. config: conf,
  59. topics: make(map[string]*topic),
  60. visitors: make(map[string]*visitor),
  61. }
  62. }
  63. func (s *Server) Run() error {
  64. go func() {
  65. ticker := time.NewTicker(s.config.ManagerInterval)
  66. for {
  67. <-ticker.C
  68. s.updateStatsAndExpire()
  69. }
  70. }()
  71. return s.listenAndServe()
  72. }
  73. func (s *Server) listenAndServe() error {
  74. log.Printf("Listening on %s", s.config.ListenHTTP)
  75. http.HandleFunc("/", s.handle)
  76. return http.ListenAndServe(s.config.ListenHTTP, nil)
  77. }
  78. func (s *Server) updateStatsAndExpire() {
  79. s.mu.Lock()
  80. defer s.mu.Unlock()
  81. // Expire visitors from rate visitors map
  82. for ip, v := range s.visitors {
  83. if time.Since(v.seen) > visitorExpungeAfter {
  84. delete(s.visitors, ip)
  85. }
  86. }
  87. // Print stats
  88. var subscribers, messages int
  89. for _, t := range s.topics {
  90. subs, msgs := t.Stats()
  91. subscribers += subs
  92. messages += msgs
  93. }
  94. log.Printf("Stats: %d topic(s), %d subscriber(s), %d message(s) sent, %d visitor(s)",
  95. len(s.topics), subscribers, messages, len(s.visitors))
  96. }
  97. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  98. if err := s.handleInternal(w, r); err != nil {
  99. if e, ok := err.(*errHTTP); ok {
  100. s.fail(w, r, e.Code, e)
  101. } else {
  102. s.fail(w, r, http.StatusInternalServerError, err)
  103. }
  104. }
  105. }
  106. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
  107. v := s.visitor(r.RemoteAddr)
  108. if !v.limiter.Allow() {
  109. return errHTTPTooManyRequests
  110. }
  111. if r.Method == http.MethodGet && r.URL.Path == "/" {
  112. return s.handleHome(w, r)
  113. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  114. return s.handleStatic(w, r)
  115. } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
  116. return s.handleSubscribeJSON(w, r)
  117. } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
  118. return s.handleSubscribeSSE(w, r)
  119. } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
  120. return s.handleSubscribeRaw(w, r)
  121. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
  122. return s.handlePublishHTTP(w, r)
  123. } else if r.Method == http.MethodOptions {
  124. return s.handleOptions(w, r)
  125. }
  126. return errHTTPNotFound
  127. }
  128. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
  129. _, err := io.WriteString(w, indexSource)
  130. return err
  131. }
  132. func (s *Server) handlePublishHTTP(w http.ResponseWriter, r *http.Request) error {
  133. t, err := s.topic(r.URL.Path[1:])
  134. if err != nil {
  135. return err
  136. }
  137. reader := io.LimitReader(r.Body, messageLimit)
  138. b, err := io.ReadAll(reader)
  139. if err != nil {
  140. return err
  141. }
  142. msg := &message{
  143. Time: time.Now().UnixMilli(),
  144. Message: string(b),
  145. }
  146. if err := t.Publish(msg); err != nil {
  147. return err
  148. }
  149. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  150. return nil
  151. }
  152. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) error {
  153. t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/json")) // Hack
  154. subscriberID := t.Subscribe(func(msg *message) error {
  155. if err := json.NewEncoder(w).Encode(&msg); err != nil {
  156. return err
  157. }
  158. if fl, ok := w.(http.Flusher); ok {
  159. fl.Flush()
  160. }
  161. return nil
  162. })
  163. defer s.unsubscribe(t, subscriberID)
  164. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  165. select {
  166. case <-t.ctx.Done():
  167. case <-r.Context().Done():
  168. }
  169. return nil
  170. }
  171. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) error {
  172. t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/sse")) // Hack
  173. subscriberID := t.Subscribe(func(msg *message) error {
  174. var buf bytes.Buffer
  175. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  176. return err
  177. }
  178. m := fmt.Sprintf("data: %s\n", buf.String())
  179. if _, err := io.WriteString(w, m); err != nil {
  180. return err
  181. }
  182. if fl, ok := w.(http.Flusher); ok {
  183. fl.Flush()
  184. }
  185. return nil
  186. })
  187. defer s.unsubscribe(t, subscriberID)
  188. w.Header().Set("Content-Type", "text/event-stream")
  189. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  190. if _, err := io.WriteString(w, "event: open\n\n"); err != nil {
  191. return err
  192. }
  193. if fl, ok := w.(http.Flusher); ok {
  194. fl.Flush()
  195. }
  196. select {
  197. case <-t.ctx.Done():
  198. case <-r.Context().Done():
  199. }
  200. return nil
  201. }
  202. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request) error {
  203. t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/raw")) // Hack
  204. subscriberID := t.Subscribe(func(msg *message) error {
  205. m := strings.ReplaceAll(msg.Message, "\n", " ") + "\n"
  206. if _, err := io.WriteString(w, m); err != nil {
  207. return err
  208. }
  209. if fl, ok := w.(http.Flusher); ok {
  210. fl.Flush()
  211. }
  212. return nil
  213. })
  214. defer s.unsubscribe(t, subscriberID)
  215. select {
  216. case <-t.ctx.Done():
  217. case <-r.Context().Done():
  218. }
  219. return nil
  220. }
  221. func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error {
  222. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  223. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST")
  224. return nil
  225. }
  226. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
  227. http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
  228. return nil
  229. }
  230. func (s *Server) createTopic(id string) *topic {
  231. s.mu.Lock()
  232. defer s.mu.Unlock()
  233. if _, ok := s.topics[id]; !ok {
  234. s.topics[id] = newTopic(id)
  235. }
  236. return s.topics[id]
  237. }
  238. func (s *Server) topic(topicID string) (*topic, error) {
  239. s.mu.Lock()
  240. defer s.mu.Unlock()
  241. c, ok := s.topics[topicID]
  242. if !ok {
  243. return nil, errHTTPNotFound
  244. }
  245. return c, nil
  246. }
  247. func (s *Server) unsubscribe(t *topic, subscriberID int) {
  248. s.mu.Lock()
  249. defer s.mu.Unlock()
  250. if subscribers := t.Unsubscribe(subscriberID); subscribers == 0 {
  251. delete(s.topics, t.id)
  252. }
  253. }
  254. // visitor creates or retrieves a rate.Limiter for the given visitor.
  255. // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
  256. func (s *Server) visitor(remoteAddr string) *visitor {
  257. s.mu.Lock()
  258. defer s.mu.Unlock()
  259. ip, _, err := net.SplitHostPort(remoteAddr)
  260. if err != nil {
  261. ip = remoteAddr // This should not happen in real life; only in tests.
  262. }
  263. v, exists := s.visitors[ip]
  264. if !exists {
  265. v = &visitor{
  266. rate.NewLimiter(s.config.Limit, s.config.LimitBurst),
  267. time.Now(),
  268. }
  269. s.visitors[ip] = v
  270. return v
  271. }
  272. v.seen = time.Now()
  273. return v
  274. }
  275. func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) {
  276. log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error())
  277. w.WriteHeader(code)
  278. io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
  279. }