Просмотр исходного кода

refactor visitor IPs and allow exempting IP Ranges

Use netip.Addr instead of storing addresses as strings. This requires
conversions at the database level and in tests, but is more memory
efficient otherwise, and facilitates the following.

Parse rate limit exemptions as netip.Prefix. This allows storing IP
ranges in the exemption list. Regular IP addresses (entered explicitly
or resolved from hostnames) are IPV4/32, denoting a range of one
address.
Karmanyaah Malhotra 3 лет назад
Родитель
Сommit
c2382d29a1

+ 33 - 4
cmd/serve.go

@@ -5,16 +5,18 @@ package cmd
 import (
 	"errors"
 	"fmt"
-	"heckel.io/ntfy/log"
 	"io/fs"
 	"math"
 	"net"
+	"net/netip"
 	"os"
 	"os/signal"
 	"strings"
 	"syscall"
 	"time"
 
+	"heckel.io/ntfy/log"
+
 	"github.com/urfave/cli/v2"
 	"github.com/urfave/cli/v2/altsrc"
 	"heckel.io/ntfy/server"
@@ -208,15 +210,15 @@ func execServe(c *cli.Context) error {
 	}
 
 	// Resolve hosts
-	visitorRequestLimitExemptIPs := make([]string, 0)
+	visitorRequestLimitExemptIPs := make([]netip.Prefix, 0)
 	for _, host := range visitorRequestLimitExemptHosts {
-		ips, err := net.LookupIP(host)
+		ips, err := parseIPHostPrefix(host)
 		if err != nil {
 			log.Warn("cannot resolve host %s: %s, ignoring visitor request exemption", host, err.Error())
 			continue
 		}
 		for _, ip := range ips {
-			visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip.String())
+			visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ip)
 		}
 	}
 
@@ -303,6 +305,33 @@ func sigHandlerConfigReload(config string) {
 	}
 }
 
+func parseIPHostPrefix(host string) (prefixes []netip.Prefix, err error) {
+        //try parsing as prefix
+        prefix, err := netip.ParsePrefix(host)
+        if err == nil {
+                prefixes = append(prefixes, prefix.Masked()) // masked and canonical for easy of debugging, shouldn't matter
+                return prefixes, nil                         // success
+        }
+
+        // not a prefix, parse as host or IP
+        // LookupHost forwards through if it's an IP
+        ips, err := net.LookupHost(host)
+        if err == nil {
+                for _, i := range ips {
+                        ip, err := netip.ParseAddr(i)
+                        if err == nil {
+                                prefix, err := ip.Prefix(ip.BitLen())
+                                if err != nil {
+                                        return prefixes, errors.New(fmt.Sprint("ip", ip, " successfully parsed as IP but unable to turn into prefix. THIS SHOULD NEVER HAPPEN. err:", err.Error()))
+                                }
+                                prefixes = append(prefixes, prefix.Masked()) //also masked canonical ip
+                        }
+                }
+        }
+        return
+}
+
+
 func reloadLogLevel(inputSource altsrc.InputSourceContext) {
 	newLevelStr, err := inputSource.String("log-level")
 	if err != nil {

+ 3 - 2
server/config.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"io/fs"
+	"net/netip"
 	"time"
 )
 
@@ -92,7 +93,7 @@ type Config struct {
 	VisitorAttachmentDailyBandwidthLimit int
 	VisitorRequestLimitBurst             int
 	VisitorRequestLimitReplenish         time.Duration
-	VisitorRequestExemptIPAddrs          []string
+	VisitorRequestExemptIPAddrs          []netip.Prefix
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitReplenish           time.Duration
 	BehindProxy                          bool
@@ -135,7 +136,7 @@ func NewConfig() *Config {
 		VisitorAttachmentDailyBandwidthLimit: DefaultVisitorAttachmentDailyBandwidthLimit,
 		VisitorRequestLimitBurst:             DefaultVisitorRequestLimitBurst,
 		VisitorRequestLimitReplenish:         DefaultVisitorRequestLimitReplenish,
-		VisitorRequestExemptIPAddrs:          make([]string, 0),
+		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0),
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
 		BehindProxy:                          false,

+ 6 - 4
server/message_cache.go

@@ -5,11 +5,13 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"net/netip"
+	"strings"
+	"time"
+
 	_ "github.com/mattn/go-sqlite3" // SQLite driver
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/util"
-	"strings"
-	"time"
 )
 
 var (
@@ -279,7 +281,7 @@ func (c *messageCache) addMessages(ms []*message) error {
 			attachmentSize,
 			attachmentExpires,
 			attachmentURL,
-			m.Sender,
+			m.Sender.String(),
 			m.Encoding,
 			published,
 		)
@@ -477,7 +479,7 @@ func readMessages(rows *sql.Rows) ([]*message, error) {
 			Icon:       icon,
 			Actions:    actions,
 			Attachment: att,
-			Sender:     sender,
+			Sender:     netip.MustParseAddr(sender), // Must parse assuming database must be correct
 			Encoding:   encoding,
 		})
 	}

+ 7 - 5
server/message_cache_test.go

@@ -3,11 +3,13 @@ package server
 import (
 	"database/sql"
 	"fmt"
-	"github.com/stretchr/testify/assert"
-	"github.com/stretchr/testify/require"
+	"net/netip"
 	"path/filepath"
 	"testing"
 	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestSqliteCache_Messages(t *testing.T) {
@@ -281,7 +283,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	expires1 := time.Now().Add(-4 * time.Hour).Unix()
 	m := newDefaultMessage("mytopic", "flower for you")
 	m.ID = "m1"
-	m.Sender = "1.2.3.4"
+	m.Sender = netip.MustParseAddr("1.2.3.4")
 	m.Attachment = &attachment{
 		Name:    "flower.jpg",
 		Type:    "image/jpeg",
@@ -294,7 +296,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	expires2 := time.Now().Add(2 * time.Hour).Unix() // Future
 	m = newDefaultMessage("mytopic", "sending you a car")
 	m.ID = "m2"
-	m.Sender = "1.2.3.4"
+	m.Sender = netip.MustParseAddr("1.2.3.4")
 	m.Attachment = &attachment{
 		Name:    "car.jpg",
 		Type:    "image/jpeg",
@@ -307,7 +309,7 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
 	expires3 := time.Now().Add(1 * time.Hour).Unix() // Future
 	m = newDefaultMessage("another-topic", "sending you another car")
 	m.ID = "m3"
-	m.Sender = "1.2.3.4"
+	m.Sender = netip.MustParseAddr("1.2.3.4")
 	m.Attachment = &attachment{
 		Name:    "another-car.jpg",
 		Type:    "image/jpeg",

+ 18 - 9
server/server.go

@@ -11,6 +11,7 @@ import (
 	"io"
 	"net"
 	"net/http"
+	"net/netip"
 	"net/url"
 	"os"
 	"path"
@@ -42,7 +43,7 @@ type Server struct {
 	smtpServerBackend *smtpBackend
 	smtpSender        mailer
 	topics            map[string]*topic
-	visitors          map[string]*visitor
+	visitors          map[netip.Addr]*visitor
 	firebaseClient    *firebaseClient
 	messages          int64
 	auth              auth.Auther
@@ -150,7 +151,7 @@ func New(conf *Config) (*Server, error) {
 		smtpSender:     mailer,
 		topics:         topics,
 		auth:           auther,
-		visitors:       make(map[string]*visitor),
+		visitors:       make(map[netip.Addr]*visitor),
 	}, nil
 }
 
@@ -642,8 +643,8 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 			return false, false, "", false, errHTTPBadRequestDelayTooLarge
 		}
 		m.Time = delay.Unix()
-		m.Sender = v.ip // Important for rate limiting
 	}
+	m.Sender = v.ip // Important for rate limiting
 	actionsStr := readParam(r, "x-actions", "actions", "action")
 	if actionsStr != "" {
 		m.Actions, err = parseActions(actionsStr)
@@ -1219,7 +1220,7 @@ func (s *Server) runFirebaseKeepaliver() {
 	if s.firebaseClient == nil {
 		return
 	}
-	v := newVisitor(s.config, s.messageCache, "0.0.0.0") // Background process, not a real visitor
+	v := newVisitor(s.config, s.messageCache, netip.MustParseAddr("0.0.0.0")) // Background process, not a real visitor
 	for {
 		select {
 		case <-time.After(s.config.FirebaseKeepaliveInterval):
@@ -1286,7 +1287,7 @@ func (s *Server) sendDelayedMessage(v *visitor, m *message) error {
 
 func (s *Server) limitRequests(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if util.Contains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
+		if util.ContainsContains(s.config.VisitorRequestExemptIPAddrs, v.ip) {
 			return next(w, r, v)
 		} else if err := v.RequestAllowed(); err != nil {
 			return errHTTPTooManyRequestsLimitRequests
@@ -1436,21 +1437,29 @@ func extractUserPass(r *http.Request) (username string, password string, ok bool
 // This function was taken from https://www.alexedwards.net/blog/how-to-rate-limit-http-requests (MIT).
 func (s *Server) visitor(r *http.Request) *visitor {
 	remoteAddr := r.RemoteAddr
-	ip, _, err := net.SplitHostPort(remoteAddr)
+	ipport, err := netip.ParseAddrPort(remoteAddr)
+	ip := ipport.Addr()
 	if err != nil {
-		ip = remoteAddr // This should not happen in real life; only in tests.
+		ip = netip.MustParseAddr(remoteAddr) // This should not happen in real life; only in tests. So, using MustParse, which panics on error.
 	}
 	if s.config.BehindProxy && strings.TrimSpace(r.Header.Get("X-Forwarded-For")) != "" {
 		// X-Forwarded-For can contain multiple addresses (see #328). If we are behind a proxy,
 		// only the right-most address can be trusted (as this is the one added by our proxy server).
 		// See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For for details.
 		ips := util.SplitNoEmpty(r.Header.Get("X-Forwarded-For"), ",")
-		ip = strings.TrimSpace(util.LastString(ips, remoteAddr))
+		myip, err := netip.ParseAddr(strings.TrimSpace(util.LastString(ips, remoteAddr)))
+		if err != nil {
+			log.Error("Invalid IP Address Received from proxy in X-Forwarded-For header. This should NEVER happen, your proxy is seriously broken: ", ip, err)
+			// fall back to regular remote address if x forwarded for is damaged
+		} else {
+			ip = myip
+		}
+
 	}
 	return s.visitorFromIP(ip)
 }
 
-func (s *Server) visitorFromIP(ip string) *visitor {
+func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 	v, exists := s.visitors[ip]

+ 6 - 4
server/server_firebase_test.go

@@ -3,13 +3,15 @@ package server
 import (
 	"encoding/json"
 	"errors"
-	"firebase.google.com/go/v4/messaging"
 	"fmt"
-	"github.com/stretchr/testify/require"
-	"heckel.io/ntfy/auth"
+	"net/netip"
 	"strings"
 	"sync"
 	"testing"
+
+	"firebase.google.com/go/v4/messaging"
+	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/auth"
 )
 
 type testAuther struct {
@@ -322,7 +324,7 @@ func TestMaybeTruncateFCMMessage_NotTooLong(t *testing.T) {
 func TestToFirebaseSender_Abuse(t *testing.T) {
 	sender := &testFirebaseSender{allowed: 2}
 	client := newFirebaseClient(sender, &testAuther{})
-	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), "1.2.3.4")
+	visitor := newVisitor(newTestConfig(t), newMemTestCache(t), netip.MustParseAddr("1.2.3.4"))
 
 	require.Nil(t, client.Send(visitor, &message{Topic: "mytopic"}))
 	require.Equal(t, 1, len(sender.Messages()))

+ 4 - 2
server/server_matrix_test.go

@@ -1,11 +1,13 @@
 package server
 
 import (
-	"github.com/stretchr/testify/require"
 	"net/http"
 	"net/http/httptest"
+	"net/netip"
 	"strings"
 	"testing"
+
+	"github.com/stretchr/testify/require"
 )
 
 func TestMatrix_NewRequestFromMatrixJSON_Success(t *testing.T) {
@@ -70,7 +72,7 @@ func TestMatrix_WriteMatrixDiscoveryResponse(t *testing.T) {
 func TestMatrix_WriteMatrixError(t *testing.T) {
 	w := httptest.NewRecorder()
 	r, _ := http.NewRequest("POST", "http://ntfy.example.com/_matrix/push/v1/notify", nil)
-	v := newVisitor(newTestConfig(t), nil, "1.2.3.4")
+	v := newVisitor(newTestConfig(t), nil, netip.MustParseAddr("1.2.3.4"))
 	require.Nil(t, writeMatrixError(w, r, v, &errMatrix{"https://ntfy.example.com/upABCDEFGHI?up=1", errHTTPBadRequestMatrixPushkeyBaseURLMismatch}))
 	require.Equal(t, 200, w.Result().StatusCode)
 	require.Equal(t, `{"rejected":["https://ntfy.example.com/upABCDEFGHI?up=1"]}`+"\n", w.Body.String())

+ 4 - 2
server/server_test.go

@@ -6,18 +6,20 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
-	"github.com/stretchr/testify/assert"
 	"io"
 	"log"
 	"math/rand"
 	"net/http"
 	"net/http/httptest"
+	"net/netip"
 	"path/filepath"
 	"strings"
 	"sync"
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/assert"
+
 	"github.com/stretchr/testify/require"
 	"heckel.io/ntfy/auth"
 	"heckel.io/ntfy/util"
@@ -814,7 +816,7 @@ func TestServer_PublishTooRequests_Defaults(t *testing.T) {
 
 func TestServer_PublishTooRequests_Defaults_ExemptHosts(t *testing.T) {
 	c := newTestConfig(t)
-	c.VisitorRequestExemptIPAddrs = []string{"9.9.9.9"} // see request()
+	c.VisitorRequestExemptIPAddrs = []netip.Prefix{netip.MustParsePrefix("9.9.9.9/32")} // see request()
 	s := newTestServer(t, c)
 	for i := 0; i < 65; i++ { // > 60
 		response := request(t, s, "PUT", "/mytopic", fmt.Sprintf("message %d", i), nil)

+ 1 - 1
server/smtp_sender.go

@@ -32,7 +32,7 @@ func (s *smtpSender) Send(v *visitor, m *message, to string) error {
 		if err != nil {
 			return err
 		}
-		message, err := formatMail(s.config.BaseURL, v.ip, s.config.SMTPSenderFrom, to, m)
+		message, err := formatMail(s.config.BaseURL, v.ip.String(), s.config.SMTPSenderFrom, to, m)
 		if err != nil {
 			return err
 		}

+ 4 - 2
server/types.go

@@ -1,9 +1,11 @@
 package server
 
 import (
-	"heckel.io/ntfy/util"
 	"net/http"
+	"net/netip"
 	"time"
+
+	"heckel.io/ntfy/util"
 )
 
 // List of possible events
@@ -33,7 +35,7 @@ type message struct {
 	Actions    []*action   `json:"actions,omitempty"`
 	Attachment *attachment `json:"attachment,omitempty"`
 	PollID     string      `json:"poll_id,omitempty"`
-	Sender     string      `json:"-"`                  // IP address of uploader, used for rate limiting
+	Sender     netip.Addr  `json:"-"`                  // IP address of uploader, used for rate limiting
 	Encoding   string      `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
 }
 

+ 7 - 5
server/visitor.go

@@ -2,10 +2,12 @@ package server
 
 import (
 	"errors"
-	"golang.org/x/time/rate"
-	"heckel.io/ntfy/util"
+	"net/netip"
 	"sync"
 	"time"
+
+	"golang.org/x/time/rate"
+	"heckel.io/ntfy/util"
 )
 
 const (
@@ -23,7 +25,7 @@ var (
 type visitor struct {
 	config        *Config
 	messageCache  *messageCache
-	ip            string
+	ip            netip.Addr
 	requests      *rate.Limiter
 	emails        *rate.Limiter
 	subscriptions util.Limiter
@@ -40,7 +42,7 @@ type visitorStats struct {
 	VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
 }
 
-func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
+func newVisitor(conf *Config, messageCache *messageCache, ip netip.Addr) *visitor {
 	return &visitor{
 		config:        conf,
 		messageCache:  messageCache,
@@ -115,7 +117,7 @@ func (v *visitor) Stale() bool {
 }
 
 func (v *visitor) Stats() (*visitorStats, error) {
-	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip)
+	attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip.String())
 	if err != nil {
 		return nil, err
 	}

+ 13 - 2
util/util.go

@@ -5,8 +5,6 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/gabriel-vasile/mimetype"
-	"golang.org/x/term"
 	"io"
 	"math/rand"
 	"os"
@@ -15,6 +13,9 @@ import (
 	"strings"
 	"sync"
 	"time"
+
+	"github.com/gabriel-vasile/mimetype"
+	"golang.org/x/term"
 )
 
 const (
@@ -45,6 +46,16 @@ func Contains[T comparable](haystack []T, needle T) bool {
 	return false
 }
 
+// ContainsContains returns true if any element of haystack .Contains(needle).
+func ContainsContains[T interface{ Contains(U) bool }, U any](haystack []T, needle U) bool {
+	for _, s := range haystack {
+		if s.Contains(needle) {
+			return true
+		}
+	}
+	return false
+}
+
 // ContainsAll returns true if all needles are contained in haystack
 func ContainsAll[T comparable](haystack []T, needles []T) bool {
 	matches := 0