binwiederhier 8 mesi fa
parent
commit
d8c8f31846
5 ha cambiato i file con 39 aggiunte e 24 eliminazioni
  1. 9 8
      cmd/serve.go
  2. 3 3
      server/config.go
  3. 6 5
      server/smtp_server.go
  4. 15 8
      server/util.go
  5. 6 0
      server/visitor.go

+ 9 - 8
cmd/serve.go

@@ -5,13 +5,6 @@ package cmd
 import (
 	"errors"
 	"fmt"
-	"github.com/stripe/stripe-go/v74"
-	"github.com/urfave/cli/v2"
-	"github.com/urfave/cli/v2/altsrc"
-	"heckel.io/ntfy/v2/log"
-	"heckel.io/ntfy/v2/server"
-	"heckel.io/ntfy/v2/user"
-	"heckel.io/ntfy/v2/util"
 	"io/fs"
 	"math"
 	"net"
@@ -22,6 +15,14 @@ import (
 	"strings"
 	"syscall"
 	"time"
+
+	"github.com/stripe/stripe-go/v74"
+	"github.com/urfave/cli/v2"
+	"github.com/urfave/cli/v2/altsrc"
+	"heckel.io/ntfy/v2/log"
+	"heckel.io/ntfy/v2/server"
+	"heckel.io/ntfy/v2/user"
+	"heckel.io/ntfy/v2/util"
 )
 
 func init() {
@@ -473,7 +474,7 @@ func sigHandlerConfigReload(config string) {
 }
 
 func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
-	// Try parsing as prefix, e.g. 10.0.1.0/24
+	// Try parsing as prefix, e.g. 10.0.1.0/24 or 2001:db8::/32
 	prefix, err := netip.ParsePrefix(host)
 	if err == nil {
 		prefixes = append(prefixes, prefix.Masked())

+ 3 - 3
server/config.go

@@ -143,9 +143,9 @@ type Config struct {
 	VisitorAuthFailureLimitReplenish     time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	VisitorSubscriberRateLimiting        bool      // Enable subscriber-based rate limiting for UnifiedPush topics
-	BehindProxy                          bool      // If true, the server will trust the proxy client IP header to determine the client IP address
-	ProxyForwardedHeader                 string    // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For"
-	ProxyTrustedAddresses                []string  // List of trusted proxy addresses that will be stripped from the Forwarded header if BehindProxy is true
+	BehindProxy                          bool      // If true, the server will trust the proxy client IP header to determine the client IP address (IPv4 and IPv6 supported)
+	ProxyForwardedHeader                 string    // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For" (IPv4 and IPv6 supported)
+	ProxyTrustedAddresses                []string  // List of trusted proxy addresses (IPv4 or IPv6) that will be stripped from the Forwarded header if BehindProxy is true
 	StripeSecretKey                      string
 	StripeWebhookKey                     string
 	StripePriceCacheDuration             time.Duration

+ 6 - 5
server/smtp_server.go

@@ -5,8 +5,6 @@ import (
 	"encoding/base64"
 	"errors"
 	"fmt"
-	"github.com/emersion/go-smtp"
-	"github.com/microcosm-cc/bluemonday"
 	"io"
 	"mime"
 	"mime/multipart"
@@ -18,6 +16,9 @@ import (
 	"regexp"
 	"strings"
 	"sync"
+
+	"github.com/emersion/go-smtp"
+	"github.com/microcosm-cc/bluemonday"
 )
 
 var (
@@ -191,9 +192,9 @@ func (s *smtpSession) publishMessage(m *message) error {
 	// Call HTTP handler with fake HTTP request
 	url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
 	req, err := http.NewRequest("POST", url, strings.NewReader(m.Message))
-	req.RequestURI = "/" + m.Topic // just for the logs
-	req.RemoteAddr = remoteAddr    // rate limiting!!
-	req.Header.Set("X-Forwarded-For", remoteAddr)
+	req.RequestURI = "/" + m.Topic                                    // just for the logs
+	req.RemoteAddr = remoteAddr                                       // rate limiting!!
+	req.Header.Set(s.backend.config.ProxyForwardedHeader, remoteAddr) // Set X-Forwarded-For header
 	if err != nil {
 		return err
 	}

+ 15 - 8
server/util.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"heckel.io/ntfy/v2/util"
 	"io"
 	"mime"
 	"net/http"
@@ -12,6 +11,8 @@ import (
 	"regexp"
 	"slices"
 	"strings"
+
+	"heckel.io/ntfy/v2/util"
 )
 
 var (
@@ -20,8 +21,9 @@ var (
 	// priorityHeaderIgnoreRegex matches specific patterns of the "Priority" header (RFC 9218), so that it can be ignored
 	priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
 
-	// forwardedHeaderRegex parses IPv4 addresses from the "Forwarded" header (RFC 7239)
-	forwardedHeaderRegex = regexp.MustCompile(`(?i)\bfor="?(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"?`)
+	// forwardedHeaderRegex parses IPv4 and IPv6 addresses from the "Forwarded" header (RFC 7239)
+	// IPv6 addresses in Forwarded header are enclosed in square brackets, e.g. for="[2001:db8::1]"
+	forwardedHeaderRegex = regexp.MustCompile(`(?i)\\bfor=\"?((?:[0-9]{1,3}\.){3}[0-9]{1,3}|\[[0-9a-fA-F:]+\])\"?`)
 )
 
 func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@@ -103,7 +105,7 @@ func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader st
 // then take the right-most address in the list (as this is the one added by our proxy server).
 // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
 func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) {
-	value := strings.TrimSpace(strings.ToLower(r.Header.Get(forwardedHeader)))
+	value := strings.TrimSpace(r.Header.Get(forwardedHeader))
 	if value == "" {
 		return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader)
 	}
@@ -111,12 +113,17 @@ func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trusted
 	addrsStrs := util.Map(util.SplitNoEmpty(value, ","), strings.TrimSpace)
 	var validAddrs []netip.Addr
 	for _, addrStr := range addrsStrs {
-		if addr, err := netip.ParseAddr(addrStr); err == nil {
-			validAddrs = append(validAddrs, addr)
-		} else if m := forwardedHeaderRegex.FindStringSubmatch(addrStr); len(m) == 2 {
-			if addr, err := netip.ParseAddr(m[1]); err == nil {
+		// Handle Forwarded header with for="[IPv6]" or for="IPv4"
+		if m := forwardedHeaderRegex.FindStringSubmatch(addrStr); len(m) == 2 {
+			addrRaw := m[1]
+			if strings.HasPrefix(addrRaw, "[") && strings.HasSuffix(addrRaw, "]") {
+				addrRaw = addrRaw[1 : len(addrRaw)-1]
+			}
+			if addr, err := netip.ParseAddr(addrRaw); err == nil {
 				validAddrs = append(validAddrs, addr)
 			}
+		} else if addr, err := netip.ParseAddr(addrStr); err == nil {
+			validAddrs = append(validAddrs, addr)
 		}
 	}
 	// Filter out proxy addresses

+ 6 - 0
server/visitor.go

@@ -528,5 +528,11 @@ func visitorID(ip netip.Addr, u *user.User) string {
 	if u != nil && u.Tier != nil {
 		return fmt.Sprintf("user:%s", u.ID)
 	}
+	if ip.Is6() {
+		// IPv6 addresses are too long to be used as visitor IDs, so we use the first 8 bytes
+		ip = netip.PrefixFrom(ip, 64).Masked().Addr()
+	} else if ip.Is4() {
+		ip = netip.PrefixFrom(ip, 20).Masked().Addr()
+	}
 	return fmt.Sprintf("ip:%s", ip.String())
 }