server.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. package server
  2. import (
  3. "bytes"
  4. "context"
  5. "embed"
  6. _ "embed" // required for go:embed
  7. "encoding/json"
  8. firebase "firebase.google.com/go"
  9. "firebase.google.com/go/messaging"
  10. "fmt"
  11. "google.golang.org/api/option"
  12. "heckel.io/ntfy/config"
  13. "io"
  14. "log"
  15. "net"
  16. "net/http"
  17. "regexp"
  18. "strconv"
  19. "strings"
  20. "sync"
  21. "time"
  22. )
  23. // TODO add "max messages in a topic" limit
  24. // TODO implement persistence
  25. // TODO implement "since=<ID>"
  26. // Server is the main server
  27. type Server struct {
  28. config *config.Config
  29. topics map[string]*topic
  30. visitors map[string]*visitor
  31. firebase subscriber
  32. messages int64
  33. mu sync.Mutex
  34. }
  35. // errHTTP is a generic HTTP error for any non-200 HTTP error
  36. type errHTTP struct {
  37. Code int
  38. Status string
  39. }
  40. func (e errHTTP) Error() string {
  41. return fmt.Sprintf("http: %s", e.Status)
  42. }
  43. const (
  44. messageLimit = 1024
  45. )
  46. var (
  47. topicRegex = regexp.MustCompile(`^/[^/]+$`)
  48. jsonRegex = regexp.MustCompile(`^/[^/]+/json$`)
  49. sseRegex = regexp.MustCompile(`^/[^/]+/sse$`)
  50. rawRegex = regexp.MustCompile(`^/[^/]+/raw$`)
  51. staticRegex = regexp.MustCompile(`^/static/.+`)
  52. //go:embed "index.html"
  53. indexSource string
  54. //go:embed static
  55. webStaticFs embed.FS
  56. errHTTPBadRequest = &errHTTP{http.StatusBadRequest, http.StatusText(http.StatusBadRequest)}
  57. errHTTPNotFound = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)}
  58. errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)}
  59. )
  60. func New(conf *config.Config) (*Server, error) {
  61. var firebaseSubscriber subscriber
  62. if conf.FirebaseKeyFile != "" {
  63. var err error
  64. firebaseSubscriber, err = createFirebaseSubscriber(conf)
  65. if err != nil {
  66. return nil, err
  67. }
  68. }
  69. return &Server{
  70. config: conf,
  71. firebase: firebaseSubscriber,
  72. topics: make(map[string]*topic),
  73. visitors: make(map[string]*visitor),
  74. }, nil
  75. }
  76. func createFirebaseSubscriber(conf *config.Config) (subscriber, error) {
  77. fb, err := firebase.NewApp(context.Background(), nil, option.WithCredentialsFile(conf.FirebaseKeyFile))
  78. if err != nil {
  79. return nil, err
  80. }
  81. msg, err := fb.Messaging(context.Background())
  82. if err != nil {
  83. return nil, err
  84. }
  85. return func(m *message) error {
  86. _, err := msg.Send(context.Background(), &messaging.Message{
  87. Topic: m.Topic,
  88. Data: map[string]string{
  89. "id": m.ID,
  90. "time": fmt.Sprintf("%d", m.Time),
  91. "event": m.Event,
  92. "topic": m.Topic,
  93. "message": m.Message,
  94. },
  95. })
  96. return err
  97. }, nil
  98. }
  99. func (s *Server) Run() error {
  100. go func() {
  101. ticker := time.NewTicker(s.config.ManagerInterval)
  102. for {
  103. <-ticker.C
  104. s.updateStatsAndExpire()
  105. }
  106. }()
  107. return s.listenAndServe()
  108. }
  109. func (s *Server) listenAndServe() error {
  110. log.Printf("Listening on %s", s.config.ListenHTTP)
  111. http.HandleFunc("/", s.handle)
  112. return http.ListenAndServe(s.config.ListenHTTP, nil)
  113. }
  114. func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
  115. if err := s.handleInternal(w, r); err != nil {
  116. if e, ok := err.(*errHTTP); ok {
  117. s.fail(w, r, e.Code, e)
  118. } else {
  119. s.fail(w, r, http.StatusInternalServerError, err)
  120. }
  121. }
  122. }
  123. func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
  124. v := s.visitor(r.RemoteAddr)
  125. if err := v.RequestAllowed(); err != nil {
  126. return err
  127. }
  128. if r.Method == http.MethodGet && r.URL.Path == "/" {
  129. return s.handleHome(w, r)
  130. } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
  131. return s.handleStatic(w, r)
  132. } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
  133. return s.handlePublish(w, r, v)
  134. } else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
  135. return s.handleSubscribeJSON(w, r, v)
  136. } else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
  137. return s.handleSubscribeSSE(w, r, v)
  138. } else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
  139. return s.handleSubscribeRaw(w, r, v)
  140. } else if r.Method == http.MethodOptions {
  141. return s.handleOptions(w, r)
  142. }
  143. return errHTTPNotFound
  144. }
  145. func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
  146. _, err := io.WriteString(w, indexSource)
  147. return err
  148. }
  149. func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
  150. http.FileServer(http.FS(webStaticFs)).ServeHTTP(w, r)
  151. return nil
  152. }
  153. func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
  154. t, err := s.topic(r.URL.Path[1:])
  155. if err != nil {
  156. return err
  157. }
  158. reader := io.LimitReader(r.Body, messageLimit)
  159. b, err := io.ReadAll(reader)
  160. if err != nil {
  161. return err
  162. }
  163. if err := t.Publish(newDefaultMessage(t.id, string(b))); err != nil {
  164. return err
  165. }
  166. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  167. s.mu.Lock()
  168. s.messages++
  169. s.mu.Unlock()
  170. return nil
  171. }
  172. func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
  173. encoder := func(msg *message) (string, error) {
  174. var buf bytes.Buffer
  175. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  176. return "", err
  177. }
  178. return buf.String(), nil
  179. }
  180. return s.handleSubscribe(w, r, v, "json", "application/stream+json", encoder)
  181. }
  182. func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
  183. encoder := func(msg *message) (string, error) {
  184. var buf bytes.Buffer
  185. if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
  186. return "", err
  187. }
  188. if msg.Event != messageEvent {
  189. return fmt.Sprintf("event: %s\ndata: %s\n", msg.Event, buf.String()), nil // Browser's .onmessage() does not fire on this!
  190. }
  191. return fmt.Sprintf("data: %s\n", buf.String()), nil
  192. }
  193. return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder)
  194. }
  195. func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
  196. encoder := func(msg *message) (string, error) {
  197. if msg.Event == messageEvent { // only handle default events
  198. return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
  199. }
  200. return "\n", nil // "keepalive" and "open" events just send an empty line
  201. }
  202. return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder)
  203. }
  204. func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
  205. if err := v.AddSubscription(); err != nil {
  206. return errHTTPTooManyRequests
  207. }
  208. defer v.RemoveSubscription()
  209. t, err := s.topic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack
  210. if err != nil {
  211. return err
  212. }
  213. since, err := parseSince(r)
  214. if err != nil {
  215. return err
  216. }
  217. poll := r.URL.Query().Has("poll")
  218. sub := func(msg *message) error {
  219. m, err := encoder(msg)
  220. if err != nil {
  221. return err
  222. }
  223. if _, err := w.Write([]byte(m)); err != nil {
  224. return err
  225. }
  226. if fl, ok := w.(http.Flusher); ok {
  227. fl.Flush()
  228. }
  229. return nil
  230. }
  231. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  232. w.Header().Set("Content-Type", contentType)
  233. if poll {
  234. return sendOldMessages(t, since, sub)
  235. }
  236. subscriberID := t.Subscribe(sub)
  237. defer t.Unsubscribe(subscriberID)
  238. if err := sub(newOpenMessage(t.id)); err != nil { // Send out open message
  239. return err
  240. }
  241. if err := sendOldMessages(t, since, sub); err != nil {
  242. return err
  243. }
  244. for {
  245. select {
  246. case <-t.ctx.Done():
  247. return nil
  248. case <-r.Context().Done():
  249. return nil
  250. case <-time.After(s.config.KeepaliveInterval):
  251. v.Keepalive()
  252. if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message
  253. return err
  254. }
  255. }
  256. }
  257. }
  258. func sendOldMessages(t *topic, since time.Time, sub subscriber) error {
  259. if since.IsZero() {
  260. return nil
  261. }
  262. for _, m := range t.Messages(since) {
  263. if err := sub(m); err != nil {
  264. return err
  265. }
  266. }
  267. return nil
  268. }
  269. func parseSince(r *http.Request) (time.Time, error) {
  270. if !r.URL.Query().Has("since") {
  271. return time.Time{}, nil
  272. }
  273. if since, err := strconv.ParseInt(r.URL.Query().Get("since"), 10, 64); err == nil {
  274. return time.Unix(since, 0), nil
  275. }
  276. if d, err := time.ParseDuration(r.URL.Query().Get("since")); err == nil {
  277. return time.Now().Add(-1 * d), nil
  278. }
  279. return time.Time{}, errHTTPBadRequest
  280. }
  281. func (s *Server) handleOptions(w http.ResponseWriter, r *http.Request) error {
  282. w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
  283. w.Header().Set("Access-Control-Allow-Methods", "GET, PUT, POST")
  284. return nil
  285. }
  286. func (s *Server) topic(id string) (*topic, error) {
  287. s.mu.Lock()
  288. defer s.mu.Unlock()
  289. if _, ok := s.topics[id]; !ok {
  290. if len(s.topics) >= s.config.GlobalTopicLimit {
  291. return nil, errHTTPTooManyRequests
  292. }
  293. s.topics[id] = newTopic(id)
  294. if s.firebase != nil {
  295. s.topics[id].Subscribe(s.firebase)
  296. }
  297. }
  298. return s.topics[id], nil
  299. }
  300. func (s *Server) updateStatsAndExpire() {
  301. s.mu.Lock()
  302. defer s.mu.Unlock()
  303. // Expire visitors from rate visitors map
  304. for ip, v := range s.visitors {
  305. if v.Stale() {
  306. delete(s.visitors, ip)
  307. }
  308. }
  309. // Prune old messages, remove subscriptions without subscribers
  310. for _, t := range s.topics {
  311. t.Prune(s.config.MessageBufferDuration)
  312. subs, msgs := t.Stats()
  313. if msgs == 0 && (subs == 0 || (s.firebase != nil && subs == 1)) { // Firebase is a subscriber!
  314. delete(s.topics, t.id)
  315. }
  316. }
  317. // Print stats
  318. var subscribers, messages int
  319. for _, t := range s.topics {
  320. subs, msgs := t.Stats()
  321. subscribers += subs
  322. messages += msgs
  323. }
  324. log.Printf("Stats: %d message(s) published, %d topic(s) active, %d subscriber(s), %d message(s) buffered, %d visitor(s)",
  325. s.messages, len(s.topics), subscribers, messages, len(s.visitors))
  326. }
  327. // visitor creates or retrieves a rate.Limiter for the given visitor.
  328. // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
  329. func (s *Server) visitor(remoteAddr string) *visitor {
  330. s.mu.Lock()
  331. defer s.mu.Unlock()
  332. ip, _, err := net.SplitHostPort(remoteAddr)
  333. if err != nil {
  334. ip = remoteAddr // This should not happen in real life; only in tests.
  335. }
  336. v, exists := s.visitors[ip]
  337. if !exists {
  338. s.visitors[ip] = newVisitor(s.config)
  339. return s.visitors[ip]
  340. }
  341. v.seen = time.Now()
  342. return v
  343. }
  344. func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) {
  345. log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error())
  346. w.WriteHeader(code)
  347. io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
  348. }