Просмотр исходного кода

Merge branch 'ip-range-exempt'

Philipp Heckel 3 лет назад
Родитель
Сommit
cbc912d1e3

+ 31 - 6
cmd/serve.go

@@ -5,16 +5,18 @@ package cmd
 import (
 	"errors"
 	"fmt"
-	"heckel.io/ntfy/log"
 	"io/fs"
 	"math"
 	"net"
+	"net/netip"
 	"os"
 	"os/signal"
 	"strings"
 	"syscall"
 	"time"
 
+	"heckel.io/ntfy/log"
+
 	"github.com/urfave/cli/v2"
 	"github.com/urfave/cli/v2/altsrc"
 	"heckel.io/ntfy/server"
@@ -208,16 +210,14 @@ func execServe(c *cli.Context) error {
 	}
 
 	// Resolve hosts
-	visitorRequestLimitExemptIPs := make([]string, 0)
+	visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
 	for _, host := range visitorRequestLimitExemptHosts {
-		ips, err := net.LookupIP(host)
+		ips, err := parseIPHostPrefix(host)
 		if err != nil {
 			log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
 			continue
 		}
-		for _, ip := range ips {
-			visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
-		}
+		visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
 	}
 
 	// Run server
@@ -303,6 +303,31 @@ func sigHandlerConfigReload(config string) {
 	}
 }
 
+func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
+	// Try parsing as prefix, e.g. 10.0.1.0/24
+	prefix, err := netip.ParsePrefix(host)
+	if err == nil {
+		prefixes = append(prefixes, prefix.Masked())
+		return prefixes, nil
+	}
+	// Not a prefix, parse as host or IP (LookupHost passes through an IP as is)
+	ips, err := net.LookupHost(host)
+	if err != nil {
+		return nil, err
+	}
+	for _, ipStr := range ips {
+		ip, err := netip.ParseAddr(ipStr)
+		if err == nil {
+			prefix, err := ip.Prefix(ip.BitLen())
+			if err != nil {
+				return nil, fmt.Errorf("%s successfully parsed but unable to make prefix: %s", ip.String(), err.Error())
+			}
+			prefixes = append(prefixes, prefix.Masked())
+		}
+	}
+	return
+}
+
 func reloadLogLevel(inputSource altsrc.InputSourceContext) {
 	newLevelStr, err := inputSource.String("log-level")
 	if err != nil {

+ 23 - 5
cmd/serve_test.go

@@ -2,17 +2,19 @@ package cmd
 
 import (
 	"fmt"
-	"github.com/gorilla/websocket"
-	"github.com/stretchr/testify/require"
-	"heckel.io/ntfy/client"
-	"heckel.io/ntfy/test"
-	"heckel.io/ntfy/util"
 	"math/rand"
 	"os"
 	"os/exec"
 	"path/filepath"
 	"testing"
 	"time"
+
+	"github.com/gorilla/websocket"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/client"
+	"heckel.io/ntfy/test"
+	"heckel.io/ntfy/util"
 )
 
 func init() {
@@ -70,6 +72,22 @@ func TestCLI_Serve_WebSocket(t *testing.T) {
 	require.Equal(t, "mytopic", m.Topic)
 }
 
+func TestIP_Host_Parsing(t *testing.T) {
+	cases := map[string]string{
+		"1.1.1.1":          "1.1.1.1/32",
+		"fd00::1234":       "fd00::1234/128",
+		"192.168.0.3/24":   "192.168.0.0/24",
+		"10.1.2.3/8":       "10.0.0.0/8",
+		"201:be93::4a6/21": "201:b800::/21",
+	}
+	for q, expectedAnswer := range cases {
+		ips, err := parseIPHostPrefix(q)
+		require.Nil(t, err)
+		assert.Equal(t, 1, len(ips))
+		assert.Equal(t, expectedAnswer, ips[0].String())
+	}
+}
+
 func newEmptyFile(t *testing.T) string {
 	filename := filepath.Join(t.TempDir(), "empty")
 	require.Nil(t, os.WriteFile(filename, []byte{}, 0600))

+ 3 - 2
server/config.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"io/fs"
+	"net/netip"
 	"time"
 )
 
@@ -92,7 +93,7 @@ type Config struct {
 	VisitorAttachmentDailyBandwidthLimit int
 	VisitorRequestLimitBurst             int
 	VisitorRequestLimitReplenish         time.Duration
-	VisitorRequestExemptIPAddrs          []string
+	VisitorRequestExemptIPAddrs          []netip.Prefix
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitReplenish           time.Duration
 	BehindProxy                          bool
@@ -135,7 +136,7 @@ func NewConfig() *Config {
 		VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
 		VisitorRequestLimitBurst:             DefaultVisitorRequestLimitBurst,
 		VisitorRequestLimitReplenish:         DefaultVisitorRequestLimitReplenish,
-		VisitorRequestExemptIPAddrs:          make([]string, 0),
+		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0),
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
 		BehindProxy:                          false,

+ 11 - 4
server/message_cache.go

@@ -5,11 +5,13 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"net/netip"
+	"strings"
+	"time"
+
 	_ "github.com/mattn/go-sqlite3" // SQLite driver
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/util"
-	"strings"
-	"time"
 )
 
 var (
@@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
 			attachmentSize,
 			attachmentExpires,
 			attachmentURL,
-			m.Sender,
+			m.Sender.String(),
 			m.Encoding,
 			published,
 		)
@@ -454,6 +456,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 				return nil, err
 			}
 		}
+		senderIP, err := netip.ParseAddr(sender)
+		if err != nil {
+			senderIP = netip.IPv4Unspecified() // if no IP stored in database, 0.0.0.0
+		}
+
 		var att *attachment
 		if attachmentName != "" && attachmentURL != "" {
 			att = &attachment{
@@ -477,7 +484,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 			Icon:       icon,
 			Actions:    actions,
 			Attachment: att,
-			Sender:     sender,
+			Sender:     senderIP, // Must parse assuming database must be correct
 			Encoding:   encoding,
 		})
 	}

+ 13 - 7
server/message_cache_test.go

@@ -3,11 +3,17 @@ package server
 import (
 	"database/sql"
 	"fmt"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
+	"net/netip"
 	"path/filepath"
 	"testing"
 	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+var (
+	exampleIP1234 = netip.MustParseAddr("1.2.3.4")
 )
 
 func TestSqliteCache_Messages(t *testing.T) {
@@ -281,7 +287,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	expires1 := time.Now().Add(-4 * time.Hour).Unix()
 	m := newDefaultMessage("mytopic", "flower for you")
 	m.ID = "m1"
-	m.Sender = "1.2.3.4"
+	m.Sender = exampleIP1234
 	m.Attachment = &attachment{
 		Name:    "flower.jpg",
 		Type:    "image/jpeg",
@@ -294,7 +300,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
 	m = newDefaultMessage("mytopic", "sending you a car")
 	m.ID = "m2"
-	m.Sender = "1.2.3.4"
+	m.Sender = exampleIP1234
 	m.Attachment = &attachment{
 		Name:    "car.jpg",
 		Type:    "image/jpeg",
@@ -307,7 +313,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
 	m = newDefaultMessage("another-topic", "sending you another car")
 	m.ID = "m3"
-	m.Sender = "1.2.3.4"
+	m.Sender = exampleIP1234
 	m.Attachment = &attachment{
 		Name:    "another-car.jpg",
 		Type:    "image/jpeg",
@@ -327,7 +333,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	require.Equal(t, int64(5000), messages[0].Attachment.Size)
 	require.Equal(t, expires1, messages[0].Attachment.Expires)
 	require.Equal(t, "https://ntfy.sh/file/AbDeFgJhal.jpg", messages[0].Attachment.URL)
-	require.Equal(t, "1.2.3.4", messages[0].Sender)
+	require.Equal(t, "1.2.3.4", messages[0].Sender.String())
 
 	require.Equal(t, "sending you a car", messages[1].Message)
 	require.Equal(t, "car.jpg", messages[1].Attachment.Name)
@@ -335,7 +341,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	require.Equal(t, int64(10000), messages[1].Attachment.Size)
 	require.Equal(t, expires2, messages[1].Attachment.Expires)
 	require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
-	require.Equal(t, "1.2.3.4", messages[1].Sender)
+	require.Equal(t, "1.2.3.4", messages[1].Sender.String())
 
 	size, err := c.AttachmentBytesUsed("1.2.3.4")
 	require.Nil(t, err)

+ 21 - 8
server/server.go

@@ -11,6 +11,7 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"net/netip"
 	"net/url"
 	"os"
 	"path"
@@ -42,7 +43,7 @@ type Server struct {
 	smtpServerBackend *smtpBackend
 	smtpSender        mailer
 	topics            map[string]*topic
-	visitors          map[string]*visitor
+	visitors          map[netip.Addr]*visitor
 	firebaseClient    *firebaseClient
 	messages          int64
 	auth              auth.Auther
@@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
 		smtpSender:     mailer,
 		topics:         topics,
 		auth:           auther,
-		visitors:       make(map[string]*visitor),
+		visitors:       make(map[netip.Addr]*visitor),
 	}, nil
 }
 
@@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
 	if s.firebaseClient == nil {
 		return
 	}
-	v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
+	v := newVisitor(s.config, s.messageCache, netip.IPv4Unspecified()) // Background process, not a real visitor, uses IP 0.0.0.0
 	for {
 		select {
 		case <-time.After(s.config.FirebaseKeepaliveInterval):
@@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
 
 func (s *Server) limitRequests(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
+		if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
 			return next(w, r, v)
 		} else if err := v.RequestAllowed(); err != nil {
 			return errHTTPTooManyRequestsLimitRequests
@@ -1436,21 +1437,33 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
 // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
 func (s *Server) visitor(r *http.Request) *visitor {
 	remoteAddr := r.RemoteAddr
-	ip, _, err := net.SplitHostPort(remoteAddr)
+	addrPort, err := netip.ParseAddrPort(remoteAddr)
+	ip := addrPort.Addr()
 	if err != nil {
-		ip = remoteAddr // This should not happen in real life; only in tests.
+		// This should not happen in real life; only in tests. So, using falling back to 0.0.0.0 if address unspecified
+		ip, err = netip.ParseAddr(remoteAddr)
+		if err != nil {
+			ip = netip.IPv4Unspecified()
+			log.Warn("unable to parse IP (%s), new visitor with unspecified IP (0.0.0.0) created %s", remoteAddr, err)
+		}
 	}
 	if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
 		// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
 		// only the right-most address can be trusted (as this is the one added by our proxy server).
 		// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
 		ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
-		ip = strings.TrimSpace(util.LastString(ips, remoteAddr))
+		realIP, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
+		if err != nil {
+			log.Error("invalid IP address %s received in X-Forwarded-For header: %s", ip, err.Error())
+			// Fall back to regular remote address if X-Forwarded-For is damaged
+		} else {
+			ip = realIP
+		}
 	}
 	return s.visitorFromIP(ip)
 }
 
-func (s *Server) visitorFromIP(ip string) *visitor {
+func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	v, exists := s.visitors[ip]

+ 6 - 4
server/server_firebase_test.go

@@ -3,13 +3,15 @@ package server
 import (
 	"encoding/json"
 	"errors"
-	"firebase.google.com/go/v4/messaging"
 	"fmt"
-	"github.com/stretchr/testify/require"
-	"heckel.io/ntfy/auth"
+	"net/netip"
 	"strings"
 	"sync"
 	"testing"
+
+	"firebase.google.com/go/v4/messaging"
+	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/auth"
 )
 
 type testAuther struct {
@@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
 func TestToFirebaseSender_Abuse(t *testing.T) {
 	sender := &testFirebaseSender{allowed: 2}
 	client := newFirebaseClient(sender, &testAuther{})
-	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
+	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
 
 	require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
 	require.Equal(t, 1, len(sender.Messages()))

+ 4 - 2
server/server_matrix_test.go

@@ -1,11 +1,13 @@
 package server
 
 import (
-	"github.com/stretchr/testify/require"
 	"net/http"
 	"net/http/httptest"
+	"net/netip"
 	"strings"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
@@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
 func TestMatrix_WriteMatrixError(t *testing.T) {
 	w := httptest.NewRecorder()
 	r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
-	v := newVisitor(newTestConfig(t), nil, "1.2.3.4")
+	v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
 	require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
 	require.Equal(t, 200, w.Result().StatusCode)
 	require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())

+ 13 - 11
server/server_test.go

@@ -6,18 +6,20 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"github.com/stretchr/testify/assert"
 	"io"
 	"log"
 	"math/rand"
 	"net/http"
 	"net/http/httptest"
+	"net/netip"
 	"path/filepath"
 	"strings"
 	"sync"
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+
 	"github.com/stretchr/testify/require"
 	"heckel.io/ntfy/auth"
 	"heckel.io/ntfy/util"
@@ -292,13 +294,13 @@ func TestServer_PublishAt(t *testing.T) {
 	messages = toMessages(t, response.Body.String())
 	require.Equal(t, 1, len(messages))
 	require.Equal(t, "a message", messages[0].Message)
-	require.Equal(t, "", messages[0].Sender) // Never return the sender!
+	require.Equal(t, netip.Addr{}, messages[0].Sender) // Never return the sender!
 
 	messages, err := s.messageCache.Messages("mytopic", sinceAllMessages, true)
 	require.Nil(t, err)
 	require.Equal(t, 1, len(messages))
 	require.Equal(t, "a message", messages[0].Message)
-	require.Equal(t, "9.9.9.9", messages[0].Sender) // It's stored in the DB though!
+	require.Equal(t, "9.9.9.9", messages[0].Sender.String()) // It's stored in the DB though!
 }
 
 func TestServer_PublishAtWithCacheError(t *testing.T) {
@@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
 
 func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
 	c := newTestConfig(t)
-	c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request()
+	c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
 	s := newTestServer(t, c)
 	for i := 0; i < 65; i++ { // > 60
 		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)
@@ -1132,7 +1134,7 @@ func TestServer_PublishAttachment(t *testing.T) {
 	require.Equal(t, int64(5000), msg.Attachment.Size)
 	require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(179*time.Minute).Unix()) // Almost 3 hours
 	require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
-	require.Equal(t, "", msg.Sender) // Should never be returned
+	require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
 	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
 
 	// GET
@@ -1168,7 +1170,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
 	require.Equal(t, int64(21), msg.Attachment.Size)
 	require.GreaterOrEqual(t, msg.Attachment.Expires, time.Now().Add(3*time.Hour).Unix())
 	require.Contains(t, msg.Attachment.URL, "http://127.0.0.1:12345/file/")
-	require.Equal(t, "", msg.Sender) // Should never be returned
+	require.Equal(t, netip.Addr{}, msg.Sender) // Should never be returned
 	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
 
 	path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
@@ -1195,7 +1197,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
 	require.Equal(t, "", msg.Attachment.Type)
 	require.Equal(t, int64(0), msg.Attachment.Size)
 	require.Equal(t, int64(0), msg.Attachment.Expires)
-	require.Equal(t, "", msg.Sender)
+	require.Equal(t, netip.Addr{}, msg.Sender)
 
 	// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
 	size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
@@ -1216,7 +1218,7 @@ func TestServer_PublishAttachmentExternalWithFilename(t *testing.T) {
 	require.Equal(t, "", msg.Attachment.Type)
 	require.Equal(t, int64(0), msg.Attachment.Size)
 	require.Equal(t, int64(0), msg.Attachment.Expires)
-	require.Equal(t, "", msg.Sender)
+	require.Equal(t, netip.Addr{}, msg.Sender)
 }
 
 func TestServer_PublishAttachmentBadURL(t *testing.T) {
@@ -1391,7 +1393,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
 	r.RemoteAddr = "8.9.10.11"
 	r.Header.Set("X-Forwarded-For", "  ") // Spaces, not empty!
 	v := s.visitor(r)
-	require.Equal(t, "8.9.10.11", v.ip)
+	require.Equal(t, "8.9.10.11", v.ip.String())
 }
 
 func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
@@ -1402,7 +1404,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
 	r.RemoteAddr = "8.9.10.11"
 	r.Header.Set("X-Forwarded-For", "1.1.1.1")
 	v := s.visitor(r)
-	require.Equal(t, "1.1.1.1", v.ip)
+	require.Equal(t, "1.1.1.1", v.ip.String())
 }
 
 func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
@@ -1413,7 +1415,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
 	r.RemoteAddr = "8.9.10.11"
 	r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
 	v := s.visitor(r)
-	require.Equal(t, "234.5.2.1", v.ip)
+	require.Equal(t, "234.5.2.1", v.ip.String())
 }
 
 func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {

+ 1 - 1
server/smtp_sender.go

@@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
 		if err != nil {
 			return err
 		}
-		message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m)
+		message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m)
 		if err != nil {
 			return err
 		}

+ 4 - 2
server/types.go

@@ -1,9 +1,11 @@
 package server
 
 import (
-	"heckel.io/ntfy/util"
 	"net/http"
+	"net/netip"
 	"time"
+
+	"heckel.io/ntfy/util"
 )
 
 // List of possible events
@@ -33,7 +35,7 @@ type message struct {
 	Actions    []*action   `json:"actions,omitempty"`
 	Attachment *attachment `json:"attachment,omitempty"`
 	PollID     string      `json:"poll_id,omitempty"`
-	Sender     string      `json:"-"`                  // IP address of uploader, used for rate limiting
+	Sender     netip.Addr  `json:"-"`                  // IP address of uploader, used for rate limiting
 	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
 }
 

+ 7 - 5
server/visitor.go

@@ -2,10 +2,12 @@ package server
 
 import (
 	"errors"
-	"golang.org/x/time/rate"
-	"heckel.io/ntfy/util"
+	"net/netip"
 	"sync"
 	"time"
+
+	"golang.org/x/time/rate"
+	"heckel.io/ntfy/util"
 )
 
 const (
@@ -23,7 +25,7 @@ var (
 type visitor struct {
 	config        *Config
 	messageCache  *messageCache
-	ip            string
+	ip            netip.Addr
 	requests      *rate.Limiter
 	emails        *rate.Limiter
 	subscriptions util.Limiter
@@ -40,7 +42,7 @@ type visitorStats struct {
 	VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
 }
 
-func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
+func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor {
 	return &visitor{
 		config:        conf,
 		messageCache:  messageCache,
@@ -115,7 +117,7 @@ func (v *visitor) Stale() bool {
 }
 
 func (v *visitor) Stats() (*visitorStats, error) {
-	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip)
+	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
 	if err != nil {
 		return nil, err
 	}

+ 14 - 2
util/util.go

@@ -5,16 +5,18 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gabriel-vasile/mimetype"
-	"golang.org/x/term"
 	"io"
 	"math/rand"
+	"net/netip"
 	"os"
 	"regexp"
 	"strconv"
 	"strings"
 	"sync"
 	"time"
+
+	"github.com/gabriel-vasile/mimetype"
+	"golang.org/x/term"
 )
 
 const (
@@ -45,6 +47,16 @@ func Contains[T comparable](haystack []T, needle T) bool {
 	return false
 }
 
+// ContainsIP returns true if any one of the of prefixes contains the ip.
+func ContainsIP(haystack []netip.Prefix, needle netip.Addr) bool {
+	for _, s := range haystack {
+		if s.Contains(needle) {
+			return true
+		}
+	}
+	return false
+}
+
 // ContainsAll returns true if all needles are contained in haystack
 func ContainsAll[T comparable](haystack []T, needles []T) bool {
 	matches := 0

+ 10 - 1
util/util_test.go

@@ -1,10 +1,12 @@
 package util
 
 import (
-	"github.com/stretchr/testify/require"
+	"net/netip"
 	"os"
 	"path/filepath"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestRandomString(t *testing.T) {
@@ -42,6 +44,13 @@ func TestContains(t *testing.T) {
 	require.False(t, Contains(s, 3))
 }
 
+func TestContainsIP(t *testing.T) {
+	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.1.1.1")))
+	require.True(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fd12:1234:5678::9876")))
+	require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("1.2.0.1")))
+	require.False(t, ContainsIP([]netip.Prefix{netip.MustParsePrefix("fd00::/8"), netip.MustParsePrefix("1.1.0.0/16")}, netip.MustParseAddr("fc00::1")))
+}
+
 func TestSplitNoEmpty(t *testing.T) {
 	require.Equal(t, []string{}, SplitNoEmpty("", ","))
 	require.Equal(t, []string{}, SplitNoEmpty(",,,", ","))