Ver Fonte

Polish a little

binwiederhier há 3 anos atrás
pai
commit
bdeec4d297
5 ficheiros alterados com 63 adições e 40 exclusões
  1. 38 29
      server/server.go
  2. 11 3
      server/server_middleware.go
  3. 1 1
      server/server_test.go
  4. 4 0
      server/topic.go
  5. 9 7
      server/util.go

+ 38 - 29
server/server.go

@@ -104,15 +104,15 @@ var (
 )
 )
 
 
 const (
 const (
-	firebaseControlTopic      = "~control"                // See Android if changed
-	firebasePollTopic         = "~poll"                   // See iOS if changed
-	emptyMessageBody          = "triggered"               // Used if message body is empty
-	newMessageBody            = "New message"             // Used in poll requests as generic message
-	defaultAttachmentMessage  = "You received a file: %s" // Used if message body is empty, and there is an attachment
-	encodingBase64            = "base64"                  // Used mainly for binary UnifiedPush messages
-	jsonBodyBytesLimit        = 16384
-	unifiedPushTopicPrefix    = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
-	rateVisitorExpiryDuration = 12 * time.Hour
+	firebaseControlTopic     = "~control"                // See Android if changed
+	firebasePollTopic        = "~poll"                   // See iOS if changed
+	emptyMessageBody         = "triggered"               // Used if message body is empty
+	newMessageBody           = "New message"             // Used in poll requests as generic message
+	defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
+	encodingBase64           = "base64"                  // Used mainly for binary UnifiedPush messages
+	jsonBodyBytesLimit       = 16384
+	unifiedPushTopicPrefix   = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
+	rateTopicsWildcard       = "*"  // Allows defining all topics in the request subscriber-rate-limited topics
 )
 )
 
 
 // WebSocket constants
 // WebSocket constants
@@ -571,11 +571,11 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
 }
 }
 
 
 func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
 func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
-	vrate, ok := r.Context().Value("vRate").(*visitor)
+	vrate, ok := r.Context().Value(contextRateVisitor).(*visitor)
 	if !ok {
 	if !ok {
 		return nil, errHTTPInternalError
 		return nil, errHTTPInternalError
 	}
 	}
-	t, ok := r.Context().Value("topic").(*topic)
+	t, ok := r.Context().Value(contextTopic).(*topic)
 	if !ok {
 	if !ok {
 		return nil, errHTTPInternalError
 		return nil, errHTTPInternalError
 	}
 	}
@@ -709,7 +709,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
 	}
 	}
 }
 }
 
 
-func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
+func (s *Server) parsePublishParams(r *http.Request, vrate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
 	cache = readBoolParam(r, true, "x-cache", "cache")
 	cache = readBoolParam(r, true, "x-cache", "cache")
 	firebase = readBoolParam(r, true, "x-firebase", "firebase")
 	firebase = readBoolParam(r, true, "x-firebase", "firebase")
 	m.Title = readParam(r, "x-title", "title", "t")
 	m.Title = readParam(r, "x-title", "title", "t")
@@ -749,7 +749,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message)
 	}
 	}
 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
 	if email != "" {
 	if email != "" {
-		if !vRate.EmailAllowed() {
+		if !vrate.EmailAllowed() {
 			return false, false, "", false, errHTTPTooManyRequestsLimitEmails
 			return false, false, "", false, errHTTPTooManyRequestsLimitEmails
 		}
 		}
 	}
 	}
@@ -954,7 +954,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
+	poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -984,12 +984,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 		}
 		}
 		return nil
 		return nil
 	}
 	}
-	for _, t := range topics {
-		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
-		if subscriberRateLimited {
-			t.SetRateVisitor(v)
-		}
-	}
+	registerRateVisitors(topics, rateTopics, v)
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset!
 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset!
 	if poll {
 	if poll {
@@ -1042,7 +1037,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
+	poll, since, scheduled, filters, rateTopics, err := parseSubscribeParams(r)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -1125,12 +1120,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		}
 		}
 		return conn.WriteJSON(msg)
 		return conn.WriteJSON(msg)
 	}
 	}
-	for _, t := range topics {
-		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
-		if subscriberRateLimited {
-			t.SetRateVisitor(v)
-		}
-	}
+	registerRateVisitors(topics, rateTopics, v)
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	if poll {
 	if poll {
 		return s.sendOldMessages(topics, since, scheduled, v, sub)
 		return s.sendOldMessages(topics, since, scheduled, v, sub)
@@ -1158,7 +1148,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	return err
 	return err
 }
 }
 
 
-func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
+func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, rateTopics []string, err error) {
 	poll = readBoolParam(r, false, "x-poll", "poll", "po")
 	poll = readBoolParam(r, false, "x-poll", "poll", "po")
 	scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
 	scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
 	since, err = parseSince(r, poll)
 	since, err = parseSince(r, poll)
@@ -1169,10 +1159,29 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
 	if err != nil {
 	if err != nil {
 		return
 		return
 	}
 	}
-	subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
+	rateTopics = readCommaSeparatedParam(r, "x-rate-topics", "rate-topics")
 	return
 	return
 }
 }
 
 
+// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic
+// will be rate limited against the rate visitor instead of the publishing visitor.
+//
+// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
+// until the Android app will send the "Rate-Topics" header.
+func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) {
+	if len(rateTopics) > 0 && rateTopics[0] == rateTopicsWildcard {
+		for _, t := range topics {
+			t.SetRateVisitor(v)
+		}
+	} else {
+		for _, t := range topics {
+			if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
+				t.SetRateVisitor(v)
+			}
+		}
+	}
+}
+
 // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
 // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
 // marker, returning only messages that are newer than the marker.
 // marker, returning only messages that are newer than the marker.
 func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {
 func (s *Server) sendOldMessages(topics []*topic, since sinceMarker, scheduled bool, v *visitor, sub subscriber) error {

+ 11 - 3
server/server_middleware.go

@@ -1,12 +1,18 @@
 package server
 package server
 
 
 import (
 import (
-	"context"
 	"net/http"
 	"net/http"
 
 
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 )
 )
 
 
+type contextKey int
+
+const (
+	contextRateVisitor contextKey = iota + 2586
+	contextTopic
+)
+
 func (s *Server) limitRequests(next handleFunc) handleFunc {
 func (s *Server) limitRequests(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
@@ -29,8 +35,10 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
 		if rateVisitor := t.RateVisitor(); rateVisitor != nil {
 		if rateVisitor := t.RateVisitor(); rateVisitor != nil {
 			vrate = rateVisitor
 			vrate = rateVisitor
 		}
 		}
-		r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t))
-
+		r = withContext(r, map[contextKey]any{
+			contextRateVisitor: vrate,
+			contextTopic:       t,
+		})
 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
 		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
 			return next(w, r, v)
 			return next(w, r, v)
 		} else if !vrate.RequestAllowed() {
 		} else if !vrate.RequestAllowed() {

+ 1 - 1
server/server_test.go

@@ -1899,7 +1899,7 @@ func TestServer_SubscriberRateLimiting(t *testing.T) {
 		r.RemoteAddr = "1.2.3.4"
 		r.RemoteAddr = "1.2.3.4"
 	}
 	}
 	rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
 	rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
-		"Subscriber-Rate-Limit-Topics": "subscriber1topic",
+		"Rate-Topics": "subscriber1topic",
 	}, subscriber1Fn)
 	}, subscriber1Fn)
 	require.Equal(t, 200, rr.Code)
 	require.Equal(t, 200, rr.Code)
 	require.Equal(t, "", rr.Body.String())
 	require.Equal(t, "", rr.Body.String())

+ 4 - 0
server/topic.go

@@ -8,6 +8,10 @@ import (
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/log"
 )
 )
 
 
+const (
+	rateVisitorExpiryDuration = 12 * time.Hour
+)
+
 // topic represents a channel to which subscribers can subscribe, and publishers
 // topic represents a channel to which subscribers can subscribe, and publishers
 // can publish a message
 // can publish a message
 type topic struct {
 type topic struct {

+ 9 - 7
server/util.go

@@ -1,6 +1,7 @@
 package server
 package server
 
 
 import (
 import (
+	"context"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"io"
 	"io"
 	"net/http"
 	"net/http"
@@ -45,13 +46,6 @@ func readHeaderParam(r *http.Request, names ...string) string {
 	return ""
 	return ""
 }
 }
 
 
-func readHeaderParamValues(r *http.Request, names ...string) (values []string) {
-	for _, name := range names {
-		values = append(values, r.Header.Values(name)...)
-	}
-	return
-}
-
 func readQueryParam(r *http.Request, names ...string) string {
 func readQueryParam(r *http.Request, names ...string) string {
 	for _, name := range names {
 	for _, name := range names {
 		value := r.URL.Query().Get(strings.ToLower(name))
 		value := r.URL.Query().Get(strings.ToLower(name))
@@ -103,3 +97,11 @@ func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T,
 	}
 	}
 	return obj, nil
 	return obj, nil
 }
 }
+
+func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
+	c := r.Context()
+	for k, v := range ctx {
+		c = context.WithValue(c, k, v)
+	}
+	return r.WithContext(c)
+}