binwiederhier 8 ماه پیش
والد
کامیت
d8c8f31846
5فایلهای تغییر یافته به همراه39 افزوده شده و 24 حذف شده
  1. 9 8
      cmd/serve.go
  2. 3 3
      server/config.go
  3. 6 5
      server/smtp_server.go
  4. 15 8
      server/util.go
  5. 6 0
      server/visitor.go

+ 9 - 8
cmd/serve.go

@@ -5,13 +5,6 @@ package cmd
 import (
 import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/stripe/stripe-go/v74"
-	"github.com/urfave/cli/v2"
-	"github.com/urfave/cli/v2/altsrc"
-	"heckel.io/ntfy/v2/log"
-	"heckel.io/ntfy/v2/server"
-	"heckel.io/ntfy/v2/user"
-	"heckel.io/ntfy/v2/util"
 	"io/fs"
 	"io/fs"
 	"math"
 	"math"
 	"net"
 	"net"
@@ -22,6 +15,14 @@ import (
 	"strings"
 	"strings"
 	"syscall"
 	"syscall"
 	"time"
 	"time"
+
+	"github.com/stripe/stripe-go/v74"
+	"github.com/urfave/cli/v2"
+	"github.com/urfave/cli/v2/altsrc"
+	"heckel.io/ntfy/v2/log"
+	"heckel.io/ntfy/v2/server"
+	"heckel.io/ntfy/v2/user"
+	"heckel.io/ntfy/v2/util"
 )
 )
 
 
 func init() {
 func init() {
@@ -473,7 +474,7 @@ func sigHandlerConfigReload(config string) {
 }
 }
 
 
 func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
 func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
-	// Try parsing as prefix, e.g. 10.0.1.0/24
+	// Try parsing as prefix, e.g. 10.0.1.0/24 or 2001:db8::/32
 	prefix, err := netip.ParsePrefix(host)
 	prefix, err := netip.ParsePrefix(host)
 	if err == nil {
 	if err == nil {
 		prefixes = append(prefixes, prefix.Masked())
 		prefixes = append(prefixes, prefix.Masked())

+ 3 - 3
server/config.go

@@ -143,9 +143,9 @@ type Config struct {
 	VisitorAuthFailureLimitReplenish     time.Duration
 	VisitorAuthFailureLimitReplenish     time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	VisitorSubscriberRateLimiting        bool      // Enable subscriber-based rate limiting for UnifiedPush topics
 	VisitorSubscriberRateLimiting        bool      // Enable subscriber-based rate limiting for UnifiedPush topics
-	BehindProxy                          bool      // If true, the server will trust the proxy client IP header to determine the client IP address
-	ProxyForwardedHeader                 string    // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For"
-	ProxyTrustedAddresses                []string  // List of trusted proxy addresses that will be stripped from the Forwarded header if BehindProxy is true
+	BehindProxy                          bool      // If true, the server will trust the proxy client IP header to determine the client IP address (IPv4 and IPv6 supported)
+	ProxyForwardedHeader                 string    // The header field to read the real/client IP address from, if BehindProxy is true, defaults to "X-Forwarded-For" (IPv4 and IPv6 supported)
+	ProxyTrustedAddresses                []string  // List of trusted proxy addresses (IPv4 or IPv6) that will be stripped from the Forwarded header if BehindProxy is true
 	StripeSecretKey                      string
 	StripeSecretKey                      string
 	StripeWebhookKey                     string
 	StripeWebhookKey                     string
 	StripePriceCacheDuration             time.Duration
 	StripePriceCacheDuration             time.Duration

+ 6 - 5
server/smtp_server.go

@@ -5,8 +5,6 @@ import (
 	"encoding/base64"
 	"encoding/base64"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"github.com/emersion/go-smtp"
-	"github.com/microcosm-cc/bluemonday"
 	"io"
 	"io"
 	"mime"
 	"mime"
 	"mime/multipart"
 	"mime/multipart"
@@ -18,6 +16,9 @@ import (
 	"regexp"
 	"regexp"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
+
+	"github.com/emersion/go-smtp"
+	"github.com/microcosm-cc/bluemonday"
 )
 )
 
 
 var (
 var (
@@ -191,9 +192,9 @@ func (s *smtpSession) publishMessage(m *message) error {
 	// Call HTTP handler with fake HTTP request
 	// Call HTTP handler with fake HTTP request
 	url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
 	url := fmt.Sprintf("%s/%s", s.backend.config.BaseURL, m.Topic)
 	req, err := http.NewRequest("POST", url, strings.NewReader(m.Message))
 	req, err := http.NewRequest("POST", url, strings.NewReader(m.Message))
-	req.RequestURI = "/" + m.Topic // just for the logs
-	req.RemoteAddr = remoteAddr    // rate limiting!!
-	req.Header.Set("X-Forwarded-For", remoteAddr)
+	req.RequestURI = "/" + m.Topic                                    // just for the logs
+	req.RemoteAddr = remoteAddr                                       // rate limiting!!
+	req.Header.Set(s.backend.config.ProxyForwardedHeader, remoteAddr) // Set X-Forwarded-For header
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}

+ 15 - 8
server/util.go

@@ -4,7 +4,6 @@ import (
 	"context"
 	"context"
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
-	"heckel.io/ntfy/v2/util"
 	"io"
 	"io"
 	"mime"
 	"mime"
 	"net/http"
 	"net/http"
@@ -12,6 +11,8 @@ import (
 	"regexp"
 	"regexp"
 	"slices"
 	"slices"
 	"strings"
 	"strings"
+
+	"heckel.io/ntfy/v2/util"
 )
 )
 
 
 var (
 var (
@@ -20,8 +21,9 @@ var (
 	// priorityHeaderIgnoreRegex matches specific patterns of the "Priority" header (RFC 9218), so that it can be ignored
 	// priorityHeaderIgnoreRegex matches specific patterns of the "Priority" header (RFC 9218), so that it can be ignored
 	priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
 	priorityHeaderIgnoreRegex = regexp.MustCompile(`^u=\d,\s*(i|\d)$|^u=\d$`)
 
 
-	// forwardedHeaderRegex parses IPv4 addresses from the "Forwarded" header (RFC 7239)
-	forwardedHeaderRegex = regexp.MustCompile(`(?i)\bfor="?(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})"?`)
+	// forwardedHeaderRegex parses IPv4 and IPv6 addresses from the "Forwarded" header (RFC 7239)
+	// IPv6 addresses in Forwarded header are enclosed in square brackets, e.g. for="[2001:db8::1]"
+	forwardedHeaderRegex = regexp.MustCompile(`(?i)\\bfor=\"?((?:[0-9]{1,3}\.){3}[0-9]{1,3}|\[[0-9a-fA-F:]+\])\"?`)
 )
 )
 
 
 func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
 func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@@ -103,7 +105,7 @@ func extractIPAddress(r *http.Request, behindProxy bool, proxyForwardedHeader st
 // then take the right-most address in the list (as this is the one added by our proxy server).
 // then take the right-most address in the list (as this is the one added by our proxy server).
 // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
 // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
 func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) {
 func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trustedAddresses []string) (netip.Addr, error) {
-	value := strings.TrimSpace(strings.ToLower(r.Header.Get(forwardedHeader)))
+	value := strings.TrimSpace(r.Header.Get(forwardedHeader))
 	if value == "" {
 	if value == "" {
 		return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader)
 		return netip.IPv4Unspecified(), fmt.Errorf("no %s header found", forwardedHeader)
 	}
 	}
@@ -111,12 +113,17 @@ func extractIPAddressFromHeader(r *http.Request, forwardedHeader string, trusted
 	addrsStrs := util.Map(util.SplitNoEmpty(value, ","), strings.TrimSpace)
 	addrsStrs := util.Map(util.SplitNoEmpty(value, ","), strings.TrimSpace)
 	var validAddrs []netip.Addr
 	var validAddrs []netip.Addr
 	for _, addrStr := range addrsStrs {
 	for _, addrStr := range addrsStrs {
-		if addr, err := netip.ParseAddr(addrStr); err == nil {
-			validAddrs = append(validAddrs, addr)
-		} else if m := forwardedHeaderRegex.FindStringSubmatch(addrStr); len(m) == 2 {
-			if addr, err := netip.ParseAddr(m[1]); err == nil {
+		// Handle Forwarded header with for="[IPv6]" or for="IPv4"
+		if m := forwardedHeaderRegex.FindStringSubmatch(addrStr); len(m) == 2 {
+			addrRaw := m[1]
+			if strings.HasPrefix(addrRaw, "[") && strings.HasSuffix(addrRaw, "]") {
+				addrRaw = addrRaw[1 : len(addrRaw)-1]
+			}
+			if addr, err := netip.ParseAddr(addrRaw); err == nil {
 				validAddrs = append(validAddrs, addr)
 				validAddrs = append(validAddrs, addr)
 			}
 			}
+		} else if addr, err := netip.ParseAddr(addrStr); err == nil {
+			validAddrs = append(validAddrs, addr)
 		}
 		}
 	}
 	}
 	// Filter out proxy addresses
 	// Filter out proxy addresses

+ 6 - 0
server/visitor.go

@@ -528,5 +528,11 @@ func visitorID(ip netip.Addr, u *user.User) string {
 	if u != nil && u.Tier != nil {
 	if u != nil && u.Tier != nil {
 		return fmt.Sprintf("user:%s", u.ID)
 		return fmt.Sprintf("user:%s", u.ID)
 	}
 	}
+	if ip.Is6() {
+		// IPv6 addresses are too long to be used as visitor IDs, so we use the first 8 bytes
+		ip = netip.PrefixFrom(ip, 64).Masked().Addr()
+	} else if ip.Is4() {
+		ip = netip.PrefixFrom(ip, 20).Masked().Addr()
+	}
 	return fmt.Sprintf("ip:%s", ip.String())
 	return fmt.Sprintf("ip:%s", ip.String())
 }
 }