Quellcode durchsuchen

Rate limit exemption; relates to #144

Philipp Heckel vor 4 Jahren
Ursprung
Commit
2ad0802b65
3 geänderte Dateien mit 29 neuen und 8 gelöschten Zeilen
  1. 17 0
      cmd/serve.go
  2. 2 0
      server/config.go
  3. 10 8
      server/server.go

+ 17 - 0
cmd/serve.go

@@ -9,6 +9,7 @@ import (
 	"heckel.io/ntfy/util"
 	"log"
 	"math"
+	"net"
 	"strings"
 	"time"
 )
@@ -45,6 +46,7 @@ var flagsServe = []cli.Flag{
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-attachment-daily-bandwidth-limit", EnvVars: []string{"NTFY_VISITOR_ATTACHMENT_DAILY_BANDWIDTH_LIMIT"}, Value: "500M", Usage: "total daily attachment download/upload bandwidth limit per visitor"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
@@ -104,6 +106,7 @@ func execServe(c *cli.Context) error {
 	visitorAttachmentDailyBandwidthLimitStr := c.String("visitor-attachment-daily-bandwidth-limit")
 	visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
 	visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
+	visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
 	visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
 	behindProxy := c.Bool("behind-proxy")
@@ -164,6 +167,19 @@ func execServe(c *cli.Context) error {
 		return fmt.Errorf("config option visitor-attachment-daily-bandwidth-limit must be lower than %d", math.MaxInt)
 	}
 
+	// Resolve hosts
+	visitorRequestLimitExemptIPs := make([]string, 0)
+	for _, host := range visitorRequestLimitExemptHosts {
+		ips, err := net.LookupIP(host)
+		if err != nil {
+			log.Printf("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
+			continue
+		}
+		for _, ip := range ips {
+			visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
+		}
+	}
+
 	// Run server
 	conf := server.NewConfig()
 	conf.BaseURL = baseURL
@@ -197,6 +213,7 @@ func execServe(c *cli.Context) error {
 	conf.VisitorAttachmentDailyBandwidthLimit = int(visitorAttachmentDailyBandwidthLimit)
 	conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
 	conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
+	conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
 	conf.BehindProxy = behindProxy

+ 2 - 0
server/config.go

@@ -83,6 +83,7 @@ type Config struct {
 	VisitorAttachmentDailyBandwidthLimit int
 	VisitorRequestLimitBurst             int
 	VisitorRequestLimitReplenish         time.Duration
+	VisitorRequestExemptIPAddrs          []string
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitReplenish           time.Duration
 	BehindProxy                          bool
@@ -120,6 +121,7 @@ func NewConfig() *Config {
 		VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
 		VisitorRequestLimitBurst:             DefaultVisitorRequestLimitBurst,
 		VisitorRequestLimitReplenish:         DefaultVisitorRequestLimitReplenish,
+		VisitorRequestExemptIPAddrs:          make([]string, 0),
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
 		BehindProxy:                          false,

+ 10 - 8
server/server.go

@@ -251,16 +251,17 @@ func (s *Server) Stop() {
 }
 
 func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
-	if err := s.handleInternal(w, r); err != nil {
+	v := s.visitor(r)
+	if err := s.handleInternal(w, r, v); err != nil {
 		if websocket.IsWebSocketUpgrade(r) {
-			log.Printf("[%s] WS %s %s - %s", r.RemoteAddr, r.Method, r.URL.Path, err.Error())
+			log.Printf("[%s] WS %s %s - %s", v.ip, r.Method, r.URL.Path, err.Error())
 			return // Do not attempt to write to upgraded connection
 		}
 		httpErr, ok := err.(*errHTTP)
 		if !ok {
 			httpErr = errHTTPInternalError
 		}
-		log.Printf("[%s] HTTP %s %s - %d - %d - %s", r.RemoteAddr, r.Method, r.URL.Path, httpErr.HTTPCode, httpErr.Code, err.Error())
+		log.Printf("[%s] HTTP %s %s - %d - %d - %s", v.ip, r.Method, r.URL.Path, httpErr.HTTPCode, httpErr.Code, err.Error())
 		w.Header().Set("Content-Type", "application/json")
 		w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
 		w.WriteHeader(httpErr.HTTPCode)
@@ -268,8 +269,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
-	v := s.visitor(r)
+func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	if r.Method == http.MethodGet && r.URL.Path == "/" {
 		return s.handleHome(w, r)
 	} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
@@ -404,14 +404,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	if s.firebase != nil && firebase && !delayed {
 		go func() {
 			if err := s.firebase(m); err != nil {
-				log.Printf("Unable to publish to Firebase: %v", err.Error())
+				log.Printf("[%s] FB - Unable to publish to Firebase: %v", v.ip, err.Error())
 			}
 		}()
 	}
 	if s.mailer != nil && email != "" && !delayed {
 		go func() {
 			if err := s.mailer.Send(v.ip, email, m); err != nil {
-				log.Printf("Unable to send email: %v", err.Error())
+				log.Printf("[%s] MAIL - Unable to send email: %v", v.ip, err.Error())
 			}
 		}()
 	}
@@ -1063,7 +1063,9 @@ func (s *Server) sendDelayedMessages() error {
 
 func (s *Server) limitRequests(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if err := v.RequestAllowed(); err != nil {
+		if util.InStringList(s.config.VisitorRequestExemptIPAddrs, v.ip) {
+			return next(w, r, v)
+		} else if err := v.RequestAllowed(); err != nil {
 			return errHTTPTooManyRequestsLimitRequests
 		}
 		return next(w, r, v)