server_middleware.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package server
  2. import (
  3. "context"
  4. "net/http"
  5. "heckel.io/ntfy/util"
  6. )
  7. func (s *Server) limitRequests(next handleFunc) handleFunc {
  8. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  9. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  10. return next(w, r, v)
  11. } else if !v.RequestAllowed() {
  12. return errHTTPTooManyRequestsLimitRequests
  13. }
  14. return next(w, r, v)
  15. }
  16. }
  17. // limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context
  18. func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
  19. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  20. t, err := s.topicFromPath(r.URL.Path)
  21. if err != nil {
  22. return err
  23. }
  24. vRate := v
  25. if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
  26. vRate = topicCountsAgainst
  27. }
  28. r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))
  29. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  30. return next(w, r, v)
  31. } else if !vRate.RequestAllowed() {
  32. return errHTTPTooManyRequestsLimitRequests
  33. }
  34. return next(w, r, v)
  35. }
  36. }
  37. func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
  38. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  39. if !s.config.EnableWeb {
  40. return errHTTPNotFound
  41. }
  42. return next(w, r, v)
  43. }
  44. }
  45. func (s *Server) ensureUserManager(next handleFunc) handleFunc {
  46. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  47. if s.userManager == nil {
  48. return errHTTPNotFound
  49. }
  50. return next(w, r, v)
  51. }
  52. }
  53. func (s *Server) ensureUser(next handleFunc) handleFunc {
  54. return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  55. if v.User() == nil {
  56. return errHTTPUnauthorized
  57. }
  58. return next(w, r, v)
  59. })
  60. }
  61. func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
  62. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  63. if s.config.StripeSecretKey == "" || s.stripe == nil {
  64. return errHTTPNotFound
  65. }
  66. return next(w, r, v)
  67. }
  68. }
  69. func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
  70. return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  71. if v.User().Billing.StripeCustomerID == "" {
  72. return errHTTPBadRequestNotAPaidUser
  73. }
  74. return next(w, r, v)
  75. })
  76. }
  77. func (s *Server) withAccountSync(next handleFunc) handleFunc {
  78. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  79. err := next(w, r, v)
  80. if err == nil {
  81. s.publishSyncEventAsync(v)
  82. }
  83. return err
  84. }
  85. }