Browse Source

rate limiting impl 2.0?

Karmanyaah Malhotra 3 years ago
parent
commit
1655f584f9
4 changed files with 97 additions and 61 deletions
  1. 2 2
      server/errors.go
  2. 37 38
      server/server.go
  3. 37 19
      server/topic.go
  4. 21 2
      server/util.go

+ 2 - 2
server/errors.go

@@ -3,8 +3,9 @@ package server
 import (
 	"encoding/json"
 	"fmt"
-	"heckel.io/ntfy/log"
 	"net/http"
+
+	"heckel.io/ntfy/log"
 )
 
 // errHTTP is a generic HTTP error for any non-200 HTTP error
@@ -92,5 +93,4 @@ var (
 	errHTTPInternalError                             = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
 	errHTTPInternalErrorInvalidPath                  = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
 	errHTTPInternalErrorMissingBaseURL               = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}
-	errHTTPWontStoreMessage                          = &errHTTP{50701, http.StatusInsufficientStorage, "topic is inactive; no device available to recieve message", ""}
 )

+ 37 - 38
server/server.go

@@ -9,12 +9,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/emersion/go-smtp"
-	"github.com/gorilla/websocket"
-	"golang.org/x/sync/errgroup"
-	"heckel.io/ntfy/log"
-	"heckel.io/ntfy/user"
-	"heckel.io/ntfy/util"
 	"io"
 	"net"
 	"net/http"
@@ -30,6 +24,13 @@ import (
 	"sync"
 	"time"
 	"unicode/utf8"
+
+	"github.com/emersion/go-smtp"
+	"github.com/gorilla/websocket"
+	"golang.org/x/sync/errgroup"
+	"heckel.io/ntfy/log"
+	"heckel.io/ntfy/user"
+	"heckel.io/ntfy/util"
 )
 
 /*
@@ -605,23 +606,23 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if err != nil {
 		return nil, err
 	}
-	v_old := v
-	if strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) {
-		v = t.getBillee()
-		if v == nil {
-			return nil, errHTTPWontStoreMessage
-		}
+	vRate := v
+	if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
+		vRate = topicCountsAgainst
 	}
 
-	if !v.MessageAllowed() {
-		return nil, errHTTPTooManyRequestsLimitMessages
+	if !vRate.MessageAllowed() {
+		vRate = v
+		if !v.MessageAllowed() {
+			return nil, errHTTPTooManyRequestsLimitMessages
+		}
 	}
 	body, err := util.Peek(r.Body, s.config.MessageLimit)
 	if err != nil {
 		return nil, err
 	}
 	m := newDefaultMessage(t.ID, "")
-	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, v, m)
+	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m)
 	if err != nil {
 		return nil, err
 	}
@@ -630,7 +631,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	}
 	m.Sender = v.IP()
 	m.User = v.MaybeUserID()
-	m.Expires = time.Unix(m.Time, 0).Add(v.Limits().MessageExpiryDuration).Unix()
+	m.Expires = time.Unix(m.Time, 0).Add(vRate.Limits().MessageExpiryDuration).Unix()
 	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
 		return nil, err
 	}
@@ -638,18 +639,18 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 		m.Message = emptyMessageBody
 	}
 	delayed := m.Time > time.Now().Unix()
-	logvrm(v, r, m).
+	logvrm(vRate, r, m).
 		Tag(tagPublish).
 		Fields(log.Context{
-			"message_delayed":     delayed,
-			"message_firebase":    firebase,
-			"message_unifiedpush": unifiedpush,
-			"message_email":       email,
+			"message_delayed":                 delayed,
+			"message_firebase":                firebase,
+			"message_unifiedpush":             unifiedpush,
+			"message_email":                   email,
+			"message_subscriber_rate_limited": vRate != v,
 		}).
 		Debug("Received message")
-		//Where should I log the original visitor vs the billing visitor
 	if log.IsTrace() {
-		logvrm(v_old, r, m).
+		logvrm(vRate, r, m).
 			Tag(tagPublish).
 			Field("message_body", util.MaybeMarshalJSON(m)).
 			Trace("Message body")
@@ -659,7 +660,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 			return nil, err
 		}
 		if s.firebaseClient != nil && firebase {
-			go s.sendToFirebase(v, m)
+			go s.sendToFirebase(vRate, m)
 		}
 		if s.smtpSender != nil && email != "" {
 			go s.sendEmail(v, m, email)
@@ -745,7 +746,7 @@ func (s *Server) forwardPollRequest(v *visitor, m *message) {
 	}
 }
 
-func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
+func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
 	cache = readBoolParam(r, true, "x-cache", "cache")
 	firebase = readBoolParam(r, true, "x-firebase", "firebase")
 	m.Title = readParam(r, "x-title", "title", "t")
@@ -785,7 +786,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 	}
 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
 	if email != "" {
-		if !v.EmailAllowed() {
+		if !vRate.EmailAllowed() {
 			return false, false, "", false, errHTTPTooManyRequestsLimitEmails
 		}
 	}
@@ -800,13 +801,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 	if err != nil {
 		return false, false, "", false, errHTTPBadRequestPriorityInvalid
 	}
-	tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
-	if tagsStr != "" {
-		m.Tags = make([]string, 0)
-		for _, s := range util.SplitNoEmpty(tagsStr, ",") {
-			m.Tags = append(m.Tags, strings.TrimSpace(s))
-		}
-	}
+	m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta")
 	delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
 	if delayStr != "" {
 		if !cache {
@@ -996,7 +991,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	if err != nil {
 		return err
 	}
-	poll, since, scheduled, filters, err := parseSubscribeParams(r)
+	poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
 	if err != nil {
 		return err
 	}
@@ -1035,7 +1030,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	defer cancel()
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
-		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
+		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
 	}
 	defer func() {
 		for i, subscriberID := range subscriberIDs {
@@ -1078,7 +1074,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	if err != nil {
 		return err
 	}
-	poll, since, scheduled, filters, err := parseSubscribeParams(r)
+	poll, since, scheduled, filters, subscriberRateTopics, err := parseSubscribeParams(r)
 	if err != nil {
 		return err
 	}
@@ -1167,7 +1163,8 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	}
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
-		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
+		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
 	}
 	defer func() {
 		for i, subscriberID := range subscriberIDs {
@@ -1188,7 +1185,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	return err
 }
 
-func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, err error) {
+func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, scheduled bool, filters *queryFilter, subscriberTopics []string, err error) {
 	poll = readBoolParam(r, false, "x-poll", "poll", "po")
 	scheduled = readBoolParam(r, false, "x-scheduled", "scheduled", "sched")
 	since, err = parseSince(r, poll)
@@ -1199,6 +1196,8 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
 	if err != nil {
 		return
 	}
+
+	subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
 	return
 }
 

+ 37 - 19
server/topic.go

@@ -19,9 +19,10 @@ type topic struct {
 }
 
 type topicSubscriber struct {
-	subscriber subscriber
-	visitor    *visitor // User ID associated with this subscription, may be empty
-	cancel     func()
+	subscriber          subscriber
+	visitor             *visitor // User ID associated with this subscription, may be empty
+	cancel              func()
+	subscriberRateLimit bool
 }
 
 // subscriber is a function that is called for every new message on a topic
@@ -36,31 +37,36 @@ func newTopic(id string) *topic {
 }
 
 // Subscribe subscribes to this topic
-func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
+func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	subscriberID := rand.Int()
 	t.subscribers[subscriberID] = &topicSubscriber{
-		visitor:    visitor, // May be empty
-		subscriber: s,
-		cancel:     cancel,
+		visitor:             visitor, // May be empty
+		subscriber:          s,
+		cancel:              cancel,
+		subscriberRateLimit: subscriberRateLimit,
 	}
+
+	// if no subscriber is already handling the rate limit
+	if t.lastVisitor == nil && subscriberRateLimit {
+		t.lastVisitor = visitor
+		t.lastVisitorExpires = time.Time{}
+	}
+
 	return subscriberID
 }
 
 func (t *topic) Stale() bool {
-	return t.getBillee() == nil
-}
-
-func (t *topic) getBillee() *visitor {
-	for _, this_subscriber := range t.subscribers {
-		return this_subscriber.visitor
-	}
-	if t.lastVisitor != nil && t.lastVisitorExpires.After(time.Now()) {
+	// if Time is initialized (not the zero value) and the expiry time has passed
+	if !t.lastVisitorExpires.IsZero() && t.lastVisitorExpires.Before(time.Now()) {
 		t.lastVisitor = nil
 	}
-	return t.lastVisitor
+	return len(t.subscribers) == 0 && t.lastVisitor == nil
+}
 
+func (t *topic) Billee() *visitor {
+	return t.lastVisitor
 }
 
 // Unsubscribe removes the subscription from the list of subscribers
@@ -68,11 +74,23 @@ func (t *topic) Unsubscribe(id int) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 
-	if len(t.subscribers) == 1 {
-		t.lastVisitor = t.subscribers[id].visitor
+	deletingSub := t.subscribers[id]
+	delete(t.subscribers, id)
+
+	// look for an active subscriber (in random order) that wants to handle the rate limit
+	for _, v := range t.subscribers {
+		if v.subscriberRateLimit {
+			t.lastVisitor = v.visitor
+			t.lastVisitorExpires = time.Time{}
+			return
+		}
+	}
+
+	// if no active subscriber is found, count it towards the leaving subscriber
+	if deletingSub.subscriberRateLimit {
+		t.lastVisitor = deletingSub.visitor
 		t.lastVisitorExpires = time.Now().Add(subscriberBilledValidity)
 	}
-	delete(t.subscribers, id)
 }
 
 // Publish asynchronously publishes to all subscribers

+ 21 - 2
server/util.go

@@ -1,12 +1,13 @@
 package server
 
 import (
-	"heckel.io/ntfy/log"
-	"heckel.io/ntfy/util"
 	"io"
 	"net/http"
 	"net/netip"
 	"strings"
+
+	"heckel.io/ntfy/log"
+	"heckel.io/ntfy/util"
 )
 
 func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@@ -17,6 +18,17 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
 	return value == "1" || value == "yes" || value == "true"
 }
 
+func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) {
+	paramStr := readParam(r, names...)
+	if paramStr != "" {
+		params = make([]string, 0)
+		for _, s := range util.SplitNoEmpty(paramStr, ",") {
+			params = append(params, strings.TrimSpace(s))
+		}
+	}
+	return params
+}
+
 func readParam(r *http.Request, names ...string) string {
 	value := readHeaderParam(r, names...)
 	if value != "" {
@@ -35,6 +47,13 @@ func readHeaderParam(r *http.Request, names ...string) string {
 	return ""
 }
 
+func readHeaderParamValues(r *http.Request, names ...string) (values []string) {
+	for _, name := range names {
+		values = append(values, r.Header.Values(name)...)
+	}
+	return
+}
+
 func readQueryParam(r *http.Request, names ...string) string {
 	for _, name := range names {
 		value := r.URL.Query().Get(strings.ToLower(name))