Ver Fonte

Subscription limit

Philipp Heckel há 4 anos atrás
pai
commit
fa7a45902f
3 ficheiros alterados com 97 adições e 37 exclusões
  1. 9 6
      config/config.go
  2. 23 31
      server/server.go
  3. 65 0
      server/visitor.go

+ 9 - 6
config/config.go

@@ -17,8 +17,9 @@ const (
 // Defines the max number of requests, here:
 // 50 requests bucket, replenished at a rate of 1 per second
 var (
-	defaultLimit      = rate.Every(time.Second)
-	defaultLimitBurst = 50
+	defaultRequestLimit      = rate.Every(time.Second)
+	defaultRequestLimitBurst = 50
+	defaultSubscriptionLimit = 30 // per visitor
 )
 
 // Config is the main config struct for the application. Use New to instantiate a default config struct.
@@ -28,8 +29,9 @@ type Config struct {
 	MessageBufferDuration time.Duration
 	KeepaliveInterval     time.Duration
 	ManagerInterval       time.Duration
-	Limit                 rate.Limit
-	LimitBurst            int
+	RequestLimit          rate.Limit
+	RequestLimitBurst     int
+	SubscriptionLimit     int
 }
 
 // New instantiates a default new config
@@ -40,7 +42,8 @@ func New(listenHTTP string) *Config {
 		MessageBufferDuration: DefaultMessageBufferDuration,
 		KeepaliveInterval:     DefaultKeepaliveInterval,
 		ManagerInterval:       DefaultManagerInterval,
-		Limit:                 defaultLimit,
-		LimitBurst:            defaultLimitBurst,
+		RequestLimit:          defaultRequestLimit,
+		RequestLimitBurst:     defaultRequestLimitBurst,
+		SubscriptionLimit:     defaultSubscriptionLimit,
 	}
 }

+ 23 - 31
server/server.go

@@ -9,7 +9,6 @@ import (
 	firebase "firebase.google.com/go"
 	"firebase.google.com/go/messaging"
 	"fmt"
-	"golang.org/x/time/rate"
 	"google.golang.org/api/option"
 	"heckel.io/ntfy/config"
 	"io"
@@ -23,9 +22,8 @@ import (
 	"time"
 )
 
-// TODO add "max connections open" limit
 // TODO add "max messages in a topic" limit
-// TODO add "max topics" limit
+// TODO implement persistence
 
 // Server is the main server
 type Server struct {
@@ -37,12 +35,6 @@ type Server struct {
 	mu       sync.Mutex
 }
 
-// visitor represents an API user, and its associated rate.Limiter used for rate limiting
-type visitor struct {
-	limiter *rate.Limiter
-	seen    time.Time
-}
-
 // errHTTP is a generic HTTP error for any non-200 HTTP error
 type errHTTP struct {
 	Code   int
@@ -54,8 +46,7 @@ func (e errHTTP) Error() string {
 }
 
 const (
-	messageLimit        = 1024
-	visitorExpungeAfter = 30 * time.Minute
+	messageLimit = 1024
 )
 
 var (
@@ -147,8 +138,8 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 
 func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 	v := s.visitor(r.RemoteAddr)
-	if !v.limiter.Allow() {
-		return errHTTPTooManyRequests
+	if err := v.RequestAllowed(); err != nil {
+		return err
 	}
 	if r.Method == http.MethodGet && r.URL.Path == "/" {
 		return s.handleHome(w, r)
@@ -157,11 +148,11 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
 	} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
 		return s.handlePublish(w, r)
 	} else if r.Method == http.MethodGet && jsonRegex.MatchString(r.URL.Path) {
-		return s.handleSubscribeJSON(w, r)
+		return s.handleSubscribeJSON(w, r, v)
 	} else if r.Method == http.MethodGet && sseRegex.MatchString(r.URL.Path) {
-		return s.handleSubscribeSSE(w, r)
+		return s.handleSubscribeSSE(w, r, v)
 	} else if r.Method == http.MethodGet && rawRegex.MatchString(r.URL.Path) {
-		return s.handleSubscribeRaw(w, r)
+		return s.handleSubscribeRaw(w, r, v)
 	} else if r.Method == http.MethodOptions {
 		return s.handleOptions(w, r)
 	}
@@ -195,7 +186,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request) error {
 	return nil
 }
 
-func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) error {
+func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer
 		if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
@@ -203,10 +194,10 @@ func (s *Server) handleSubscribeJSON(w http.ResponseWriter, r *http.Request) err
 		}
 		return buf.String(), nil
 	}
-	return s.handleSubscribe(w, r, "json", "application/stream+json", encoder)
+	return s.handleSubscribe(w, r, v, "json", "application/stream+json", encoder)
 }
 
-func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) error {
+func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		var buf bytes.Buffer
 		if err := json.NewEncoder(&buf).Encode(&msg); err != nil {
@@ -217,20 +208,24 @@ func (s *Server) handleSubscribeSSE(w http.ResponseWriter, r *http.Request) erro
 		}
 		return fmt.Sprintf("data: %s\n", buf.String()), nil
 	}
-	return s.handleSubscribe(w, r, "sse", "text/event-stream", encoder)
+	return s.handleSubscribe(w, r, v, "sse", "text/event-stream", encoder)
 }
 
-func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request) error {
+func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	encoder := func(msg *message) (string, error) {
 		if msg.Event == "" { // only handle default events
 			return strings.ReplaceAll(msg.Message, "\n", " ") + "\n", nil
 		}
 		return "\n", nil // "keepalive" and "open" events just send an empty line
 	}
-	return s.handleSubscribe(w, r, "raw", "text/plain", encoder)
+	return s.handleSubscribe(w, r, v, "raw", "text/plain", encoder)
 }
 
-func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format string, contentType string, encoder messageEncoder) error {
+func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
+	if err := v.AddSubscription(); err != nil {
+		return err
+	}
+	defer v.RemoveSubscription()
 	t := s.createTopic(strings.TrimSuffix(r.URL.Path[1:], "/"+format)) // Hack
 	since, err := parseSince(r)
 	if err != nil {
@@ -270,6 +265,7 @@ func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, format
 		case <-r.Context().Done():
 			return nil
 		case <-time.After(s.config.KeepaliveInterval):
+			v.Keepalive()
 			if err := sub(newKeepaliveMessage(t.id)); err != nil { // Send keepalive message
 				return err
 			}
@@ -326,12 +322,12 @@ func (s *Server) updateStatsAndExpire() {
 
 	// Expire visitors from rate visitors map
 	for ip, v := range s.visitors {
-		if time.Since(v.seen) > visitorExpungeAfter {
+		if v.Stale() {
 			delete(s.visitors, ip)
 		}
 	}
 
-	// Prune old messages, remove topics without subscribers
+	// Prune old messages, remove subscriptions without subscribers
 	for _, t := range s.topics {
 		t.Prune(s.config.MessageBufferDuration)
 		subs, msgs := t.Stats()
@@ -362,12 +358,8 @@ func (s *Server) visitor(remoteAddr string) *visitor {
 	}
 	v, exists := s.visitors[ip]
 	if !exists {
-		v = &visitor{
-			rate.NewLimiter(s.config.Limit, s.config.LimitBurst),
-			time.Now(),
-		}
-		s.visitors[ip] = v
-		return v
+		s.visitors[ip] = newVisitor(s.config)
+		return s.visitors[ip]
 	}
 	v.seen = time.Now()
 	return v

+ 65 - 0
server/visitor.go

@@ -0,0 +1,65 @@
+package server
+
+import (
+	"golang.org/x/time/rate"
+	"heckel.io/ntfy/config"
+	"sync"
+	"time"
+)
+
+const (
+	visitorExpungeAfter = 30 * time.Minute
+)
+
+// visitor represents an API user, and its associated rate.Limiter used for rate limiting
+type visitor struct {
+	config        *config.Config
+	limiter       *rate.Limiter
+	subscriptions int
+	seen          time.Time
+	mu            sync.Mutex
+}
+
+func newVisitor(conf *config.Config) *visitor {
+	return &visitor{
+		config:  conf,
+		limiter: rate.NewLimiter(conf.RequestLimit, conf.RequestLimitBurst),
+		seen:    time.Now(),
+	}
+}
+
+func (v *visitor) RequestAllowed() error {
+	if !v.limiter.Allow() {
+		return errHTTPTooManyRequests
+	}
+	return nil
+}
+
+func (v *visitor) AddSubscription() error {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	if v.subscriptions >= v.config.SubscriptionLimit {
+		return errHTTPTooManyRequests
+	}
+	v.subscriptions++
+	return nil
+}
+
+func (v *visitor) RemoveSubscription() {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	v.subscriptions--
+}
+
+func (v *visitor) Keepalive() {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	v.seen = time.Now()
+}
+
+func (v *visitor) Stale() bool {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	v.seen = time.Now()
+	return time.Since(v.seen) > visitorExpungeAfter
+}