server_middleware.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. package server
  2. import (
  3. "net/http"
  4. "heckel.io/ntfy/util"
  5. )
  // contextKey is the unexported key type for values this package stores in a
// request's context. A distinct named type keeps these keys from comparing
// equal to keys of other types.
type contextKey int

// Context keys written by limitRequestsWithTopic. The base value 2586 has no
// meaning visible in this file (presumably arbitrary — TODO confirm); the two
// keys only need to be distinct from each other.
const (
	contextRateVisitor contextKey = 2586 // *visitor charged for the request's rate limit
	contextTopic       contextKey = 2587 // topic resolved from the request path
)
  11. func (s *Server) limitRequests(next handleFunc) handleFunc {
  12. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  13. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  14. return next(w, r, v)
  15. } else if !v.RequestAllowed() {
  16. return errHTTPTooManyRequestsLimitRequests
  17. }
  18. return next(w, r, v)
  19. }
  20. }
  21. // limitRequestsWithTopic limits requests with a topic and stores the rate-limiting-subscriber and topic into request.Context
  22. func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
  23. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  24. t, err := s.topicFromPath(r.URL.Path)
  25. if err != nil {
  26. return err
  27. }
  28. vrate := v
  29. if rateVisitor := t.RateVisitor(); rateVisitor != nil {
  30. vrate = rateVisitor
  31. }
  32. r = withContext(r, map[contextKey]any{
  33. contextRateVisitor: vrate,
  34. contextTopic: t,
  35. })
  36. if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
  37. return next(w, r, v)
  38. } else if !vrate.RequestAllowed() {
  39. return errHTTPTooManyRequestsLimitRequests
  40. }
  41. return next(w, r, v)
  42. }
  43. }
  44. func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
  45. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  46. if !s.config.EnableWeb {
  47. return errHTTPNotFound
  48. }
  49. return next(w, r, v)
  50. }
  51. }
  52. func (s *Server) ensureUserManager(next handleFunc) handleFunc {
  53. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  54. if s.userManager == nil {
  55. return errHTTPNotFound
  56. }
  57. return next(w, r, v)
  58. }
  59. }
  60. func (s *Server) ensureUser(next handleFunc) handleFunc {
  61. return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  62. if v.User() == nil {
  63. return errHTTPUnauthorized
  64. }
  65. return next(w, r, v)
  66. })
  67. }
  68. func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
  69. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  70. if s.config.StripeSecretKey == "" || s.stripe == nil {
  71. return errHTTPNotFound
  72. }
  73. return next(w, r, v)
  74. }
  75. }
  76. func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
  77. return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  78. if v.User().Billing.StripeCustomerID == "" {
  79. return errHTTPBadRequestNotAPaidUser
  80. }
  81. return next(w, r, v)
  82. })
  83. }
  84. func (s *Server) withAccountSync(next handleFunc) handleFunc {
  85. return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
  86. err := next(w, r, v)
  87. if err == nil {
  88. s.publishSyncEventAsync(v)
  89. }
  90. return err
  91. }
  92. }