binwiederhier 8 ay önce
ebeveyn
işleme
54514454bf
6 değiştirilmiş dosya ile 125 ekleme ve 20 silme
  1. 13 3
      cmd/serve.go
  2. 6 0
      server/config.go
  3. 1 1
      server/server.go
  4. 88 4
      server/server_test.go
  5. 8 3
      server/util.go
  6. 9 9
      server/visitor.go

+ 13 - 3
cmd/serve.go

@@ -80,6 +80,7 @@ var flagsServe = append(
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "message-delay-limit", Aliases: []string{"message_delay_limit"}, EnvVars: []string{"NTFY_MESSAGE_DELAY_LIMIT"}, Value: util.FormatDuration(server.DefaultMessageDelayMax), Usage: "max duration a message can be scheduled into the future"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "global-topic-limit", Aliases: []string{"global_topic_limit", "T"}, EnvVars: []string{"NTFY_GLOBAL_TOPIC_LIMIT"}, Value: server.DefaultTotalTopicLimit, Usage: "total number of topics allowed"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-subscription-limit", Aliases: []string{"visitor_subscription_limit"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIPTION_LIMIT"}, Value: server.DefaultVisitorSubscriptionLimit, Usage: "number of subscriptions per visitor"}),
+	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-attachment-total-size-limit", Aliases: []string{"visitor_attachment_total_size_limit"}, EnvVars: []string{"NTFY_VISITOR_ATTACHMENT_TOTAL_SIZE_LIMIT"}, Value: util.FormatSize(server.DefaultVisitorAttachmentTotalSizeLimit), Usage: "total storage limit used for attachments per visitor"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-attachment-daily-bandwidth-limit", Aliases: []string{"visitor_attachment_daily_bandwidth_limit"}, EnvVars: []string{"NTFY_VISITOR_ATTACHMENT_DAILY_BANDWIDTH_LIMIT"}, Value: "500M", Usage: "total daily attachment download/upload bandwidth limit per visitor"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
@@ -88,7 +89,8 @@ var flagsServe = append(
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: util.FormatDuration(server.DefaultVisitorEmailLimitReplenish), Usage: "interval at which burst limit is replenished (one per x)"}),
-	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "visitor-subscriber-rate-limiting", Aliases: []string{"visitor_subscriber_rate_limiting"}, EnvVars: []string{"NTFY_VISITOR_SUBSCRIBER_RATE_LIMITING"}, Value: false, Usage: "enables subscriber-based rate limiting"}),
+	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-prefix-bits-ipv4", Aliases: []string{"visitor_prefix_bits_ipv4"}, EnvVars: []string{"NTFY_VISITOR_PREFIX_BITS_IPV4"}, Value: server.DefaultVisitorPrefixBitsIPv4, Usage: "number of bits of the IPv4 address to use for rate limiting (default: 32, full address)"}),
+	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-prefix-bits-ipv6", Aliases: []string{"visitor_prefix_bits_ipv6"}, EnvVars: []string{"NTFY_VISITOR_PREFIX_BITS_IPV6"}, Value: server.DefaultVisitorPrefixBitsIPv6, Usage: "number of bits of the IPv6 address to use for rate limiting (default: 64, /64 subnet)"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use forwarded header (e.g. X-Forwarded-For, X-Client-IP) to determine visitor IP address (for rate limiting)"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "proxy-forwarded-header", Aliases: []string{"proxy_forwarded_header"}, EnvVars: []string{"NTFY_PROXY_FORWARDED_HEADER"}, Value: "X-Forwarded-For", Usage: "use specified header to determine visitor IP address (for rate limiting)"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "proxy-trusted-addresses", Aliases: []string{"proxy_trusted_addresses"}, EnvVars: []string{"NTFY_PROXY_TRUSTED_ADDRESSES"}, Value: "", Usage: "comma-separated list of trusted IP addresses to remove from forwarded header"}),
@@ -192,6 +194,8 @@ func execServe(c *cli.Context) error {
 	visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
 	visitorEmailLimitReplenishStr := c.String("visitor-email-limit-replenish")
+	visitorPrefixBitsIPv4 := c.Int("visitor-prefix-bits-ipv4")
+	visitorPrefixBitsIPv6 := c.Int("visitor-prefix-bits-ipv6")
 	behindProxy := c.Bool("behind-proxy")
 	proxyForwardedHeader := c.String("proxy-forwarded-header")
 	proxyTrustedAddresses := util.SplitNoEmpty(c.String("proxy-trusted-addresses"), ",")
@@ -325,6 +329,10 @@ func execServe(c *cli.Context) error {
 		return errors.New("web push expiry warning duration cannot be higher than web push expiry duration")
 	} else if behindProxy && proxyForwardedHeader == "" {
 		return errors.New("if behind-proxy is set, proxy-forwarded-header must also be set")
+	} else if visitorPrefixBitsIPv4 < 1 || visitorPrefixBitsIPv4 > 32 {
+		return errors.New("visitor-prefix-bits-ipv4 must be between 1 and 32")
+	} else if visitorPrefixBitsIPv6 < 1 || visitorPrefixBitsIPv6 > 128 {
+		return errors.New("visitor-prefix-bits-ipv6 must be between 1 and 128")
 	}
 
 	// Backwards compatibility
@@ -413,6 +421,7 @@ func execServe(c *cli.Context) error {
 	conf.MessageDelayMax = messageDelayLimit
 	conf.TotalTopicLimit = totalTopicLimit
 	conf.VisitorSubscriptionLimit = visitorSubscriptionLimit
+	conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting
 	conf.VisitorAttachmentTotalSizeLimit = visitorAttachmentTotalSizeLimit
 	conf.VisitorAttachmentDailyBandwidthLimit = visitorAttachmentDailyBandwidthLimit
 	conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
@@ -421,7 +430,8 @@ func execServe(c *cli.Context) error {
 	conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
-	conf.VisitorSubscriberRateLimiting = visitorSubscriberRateLimiting
+	conf.VisitorPrefixBitsIPv4 = visitorPrefixBitsIPv4
+	conf.VisitorPrefixBitsIPv6 = visitorPrefixBitsIPv6
 	conf.BehindProxy = behindProxy
 	conf.ProxyForwardedHeader = proxyForwardedHeader
 	conf.ProxyTrustedAddresses = proxyTrustedAddresses
@@ -434,7 +444,6 @@ func execServe(c *cli.Context) error {
 	conf.EnableMetrics = enableMetrics
 	conf.MetricsListenHTTP = metricsListenHTTP
 	conf.ProfileListenHTTP = profileListenHTTP
-	conf.Version = c.App.Version
 	conf.WebPushPrivateKey = webPushPrivateKey
 	conf.WebPushPublicKey = webPushPublicKey
 	conf.WebPushFile = webPushFile
@@ -442,6 +451,7 @@ func execServe(c *cli.Context) error {
 	conf.WebPushStartupQueries = webPushStartupQueries
 	conf.WebPushExpiryDuration = webPushExpiryDuration
 	conf.WebPushExpiryWarningDuration = webPushExpiryWarningDuration
+	conf.Version = c.App.Version
 
 	// Set up hot-reloading of config
 	go sigHandlerConfigReload(config)

+ 6 - 0
server/config.go

@@ -61,6 +61,8 @@ const (
 	DefaultVisitorAuthFailureLimitReplenish     = time.Minute
 	DefaultVisitorAttachmentTotalSizeLimit      = 100 * 1024 * 1024 // 100 MB
 	DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
+	DefaultVisitorPrefixBitsIPv4                = 32                // Use the entire IPv4 address for rate limiting
+	DefaultVisitorPrefixBitsIPv6                = 64                // Use /64 for IPv6 rate limiting
 )
 
 var (
@@ -143,6 +145,8 @@ type Config struct {
 	VisitorAuthFailureLimitReplenish     time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	VisitorSubscriberRateLimiting        bool      // Enable subscriber-based rate limiting for UnifiedPush topics
+	VisitorPrefixBitsIPv4                int       // Number of bits for IPv4 rate limiting (default: 32)
+	VisitorPrefixBitsIPv6                int       // Number of bits for IPv6 rate limiting (default: 64)
 	BehindProxy                          bool      // If true, the server will trust the proxy client IP header to determine the client IP address (IPv4 and IPv6 supported)
 	ProxyForwardedHeader                 string    // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For" (IPv4 and IPv6 supported)
 	ProxyTrustedAddresses                []string  // List of trusted proxy addresses (IPv4 or IPv6) that will be stripped from the Forwarded header if BehindProxy is true
@@ -234,6 +238,8 @@ func NewConfig() *Config {
 		VisitorAuthFailureLimitReplenish:     DefaultVisitorAuthFailureLimitReplenish,
 		VisitorStatsResetTime:                DefaultVisitorStatsResetTime,
 		VisitorSubscriberRateLimiting:        false,
+		VisitorPrefixBitsIPv4:                32,                // Default: use full IPv4 address
+		VisitorPrefixBitsIPv6:                64,                // Default: use /64 for IPv6
 		BehindProxy:                          false,             // If true, the server will trust the proxy client IP header to determine the client IP address
 		ProxyForwardedHeader:                 "X-Forwarded-For", // Default header for reverse proxy client IPs
 		StripeSecretKey:                      "",

+ 1 - 1
server/server.go

@@ -2023,7 +2023,7 @@ func (s *Server) authenticateBearerAuth(r *http.Request, token string) (*user.Us
 func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	id := visitorID(ip, user)
+	id := visitorID(ip, user, s.config)
 	v, exists := s.visitors[id]
 	if !exists {
 		s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)

+ 88 - 4
server/server_test.go

@@ -1169,7 +1169,7 @@ func (t *testMailer) Count() int {
 	return t.count
 }
 
-func TestServer_PublishTooRequests_Defaults(t *testing.T) {
+func TestServer_PublishTooManyRequests_Defaults(t *testing.T) {
 	s := newTestServer(t, newTestConfig(t))
 	for i := 0; i < 60; i++ {
 		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
@@ -1179,7 +1179,50 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
 	require.Equal(t, 429, response.Code)
 }
 
-func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
+func TestServer_PublishTooManyRequests_Defaults_IPv6(t *testing.T) {
+	s := newTestServer(t, newTestConfig(t))
+	overrideRemoteAddr1 := func(r *http.Request) {
+		r.RemoteAddr = "[2001:db8:9999:8888:1::1]:1234"
+	}
+	overrideRemoteAddr2 := func(r *http.Request) {
+		r.RemoteAddr = "[2001:db8:9999:8888:2::1]:1234" // Same /64
+	}
+	for i := 0; i < 30; i++ {
+		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr1)
+		require.Equal(t, 200, response.Code)
+	}
+	for i := 0; i < 30; i++ {
+		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr2)
+		require.Equal(t, 200, response.Code)
+	}
+	response := request(t, s, "PUT", "/mytopic", "message", nil, overrideRemoteAddr1)
+	require.Equal(t, 429, response.Code)
+}
+
+func TestServer_PublishTooManyRequests_IPv6_Slash48(t *testing.T) {
+	c := newTestConfig(t)
+	c.VisitorRequestLimitBurst = 6
+	c.VisitorPrefixBitsIPv6 = 48 // Use /48 for IPv6 prefixes
+	s := newTestServer(t, c)
+	overrideRemoteAddr1 := func(r *http.Request) {
+		r.RemoteAddr = "[2001:db8:9999::1]:1234"
+	}
+	overrideRemoteAddr2 := func(r *http.Request) {
+		r.RemoteAddr = "[2001:db8:9999::2]:1234" // Same /48
+	}
+	for i := 0; i < 3; i++ {
+		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr1)
+		require.Equal(t, 200, response.Code)
+	}
+	for i := 0; i < 3; i++ {
+		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr2)
+		require.Equal(t, 200, response.Code)
+	}
+	response := request(t, s, "PUT", "/mytopic", "message", nil, overrideRemoteAddr1)
+	require.Equal(t, 429, response.Code)
+}
+
+func TestServer_PublishTooManyRequests_Defaults_ExemptHosts(t *testing.T) {
 	c := newTestConfig(t)
 	c.VisitorRequestLimitBurst = 3
 	c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
@@ -1190,7 +1233,21 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
 	}
 }
 
-func TestServer_PublishTooRequests_Defaults_ExemptHosts_MessageDailyLimit(t *testing.T) {
+func TestServer_PublishTooManyRequests_Defaults_ExemptHosts_IPv6(t *testing.T) {
+	c := newTestConfig(t)
+	c.VisitorRequestLimitBurst = 3
+	c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("2001:db8:9999::/48")}
+	s := newTestServer(t, c)
+	overrideRemoteAddr := func(r *http.Request) {
+		r.RemoteAddr = "[2001:db8:9999::1]:1234"
+	}
+	for i := 0; i < 5; i++ { // > 3
+		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil, overrideRemoteAddr)
+		require.Equal(t, 200, response.Code)
+	}
+}
+
+func TestServer_PublishTooManyRequests_Defaults_ExemptHosts_MessageDailyLimit(t *testing.T) {
 	c := newTestConfig(t)
 	c.VisitorRequestLimitBurst = 10
 	c.VisitorMessageDailyLimit = 4
@@ -1202,7 +1259,7 @@ func TestServer_PublishTooRequests_Defaults_ExemptHosts_MessageDailyLimit(t *tes
 	}
 }
 
-func TestServer_PublishTooRequests_ShortReplenish(t *testing.T) {
+func TestServer_PublishTooManyRequests_ShortReplenish(t *testing.T) {
 	t.Parallel()
 	c := newTestConfig(t)
 	c.VisitorRequestLimitBurst = 60
@@ -2244,6 +2301,19 @@ func TestServer_Visitor_Custom_ClientIP_Header(t *testing.T) {
 	require.Equal(t, "1.2.3.4", v.ip.String())
 }
 
+func TestServer_Visitor_Custom_ClientIP_Header_IPv6(t *testing.T) {
+	c := newTestConfig(t)
+	c.BehindProxy = true
+	c.ProxyForwardedHeader = "X-Client-IP"
+	s := newTestServer(t, c)
+	r, _ := http.NewRequest("GET", "/bla", nil)
+	r.RemoteAddr = "[2001:db8:9999::1]:1234"
+	r.Header.Set("X-Client-IP", "2001:db8:7777::1")
+	v, err := s.maybeAuthenticate(r)
+	require.Nil(t, err)
+	require.Equal(t, "2001:db8:7777::1", v.ip.String())
+}
+
 func TestServer_Visitor_Custom_Forwarded_Header(t *testing.T) {
 	c := newTestConfig(t)
 	c.BehindProxy = true
@@ -2258,6 +2328,20 @@ func TestServer_Visitor_Custom_Forwarded_Header(t *testing.T) {
 	require.Equal(t, "5.6.7.8", v.ip.String())
 }
 
+func TestServer_Visitor_Custom_Forwarded_Header_IPv6(t *testing.T) {
+	c := newTestConfig(t)
+	c.BehindProxy = true
+	c.ProxyForwardedHeader = "Forwarded"
+	c.ProxyTrustedAddresses = []string{"2001:db8:1111::1"}
+	s := newTestServer(t, c)
+	r, _ := http.NewRequest("GET", "/bla", nil)
+	r.RemoteAddr = "[2001:db8:2222::1]:1234"
+	r.Header.Set("Forwarded", " for=[2001:db8:1111::1], by=example.com;for=[2001:db8:3333::1]")
+	v, err := s.maybeAuthenticate(r)
+	require.Nil(t, err)
+	require.Equal(t, "2001:db8:3333::1", v.ip.String())
+}
+
 func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
 	t.Parallel()
 	count := 50000

+ 8 - 3
server/util.go

@@ -22,8 +22,13 @@ var (
 	priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
 
 	// forwardedHeaderRegex parses IPv4 and IPv6 addresses from the "Forwarded" header (RFC 7239)
-	// IPv6 addresses in Forwarded header are enclosed in square brackets, e.g. for="[2001:db8::1]"
-	forwardedHeaderRegex = regexp.MustCompile(`(?i)\\bfor=\"?((?:[0-9]{1,3}\.){3}[0-9]{1,3}|\[[0-9a-fA-F:]+\])\"?`)
+	// IPv6 addresses in Forwarded header are enclosed in square brackets. The port is optional.
+	//
+	// Examples:
+	//  for="1.2.3.4"
+	//  for="[2001:db8::1]"; for=1.2.3.4:8080, by=phil
+	//  for="1.2.3.4:8080"
+	forwardedHeaderRegex = regexp.MustCompile(`(?i)\bfor="?(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|\[[0-9a-f:]+])(?::\d+)?"?`)
 )
 
 func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@@ -105,7 +110,7 @@ func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader st
 // then take the right-most address in the list (as this is the one added by our proxy server).
 // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
 func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) {
-	value := strings.TrimSpace(r.Header.Get(forwardedHeader))
+	value := strings.TrimSpace(strings.ToLower(r.Header.Get(forwardedHeader)))
 	if value == "" {
 		return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader)
 	}

+ 9 - 9
server/visitor.go

@@ -2,13 +2,13 @@ package server
 
 import (
 	"fmt"
-	"heckel.io/ntfy/v2/log"
-	"heckel.io/ntfy/v2/user"
 	"net/netip"
 	"sync"
 	"time"
 
 	"golang.org/x/time/rate"
+	"heckel.io/ntfy/v2/log"
+	"heckel.io/ntfy/v2/user"
 	"heckel.io/ntfy/v2/util"
 )
 
@@ -151,7 +151,7 @@ func (v *visitor) Context() log.Context {
 func (v *visitor) contextNoLock() log.Context {
 	info := v.infoLightNoLock()
 	fields := log.Context{
-		"visitor_id":                     visitorID(v.ip, v.user),
+		"visitor_id":                     visitorID(v.ip, v.user, v.config),
 		"visitor_ip":                     v.ip.String(),
 		"visitor_seen":                   util.FormatTime(v.seen),
 		"visitor_messages":               info.Stats.Messages,
@@ -524,15 +524,15 @@ func dailyLimitToRate(limit int64) rate.Limit {
 	return rate.Limit(limit) * rate.Every(oneDay)
 }
 
-func visitorID(ip netip.Addr, u *user.User) string {
+// visitorID returns a unique identifier for a visitor based on user or IP, using configurable prefix bits for IPv4/IPv6
+func visitorID(ip netip.Addr, u *user.User, conf *Config) string {
 	if u != nil && u.Tier != nil {
 		return fmt.Sprintf("user:%s", u.ID)
 	}
-	if ip.Is6() {
-		// IPv6 addresses are too long to be used as visitor IDs, so we use the first 8 bytes
-		ip = netip.PrefixFrom(ip, 64).Masked().Addr()
-	} else if ip.Is4() {
-		ip = netip.PrefixFrom(ip, 20).Masked().Addr()
+	if ip.Is4() {
+		ip = netip.PrefixFrom(ip, conf.VisitorPrefixBitsIPv4).Masked().Addr()
+	} else if ip.Is6() {
+		ip = netip.PrefixFrom(ip, conf.VisitorPrefixBitsIPv6).Masked().Addr()
 	}
 	return fmt.Sprintf("ip:%s", ip.String())
 }