Explorar o código

Rate limits make sense now!

binwiederhier %!s(int64=3) %!d(string=hai) anos
pai
achega
c874a641df

+ 3 - 0
cmd/serve.go

@@ -77,6 +77,7 @@ var flagsServe = append(
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-request-limit-burst", Aliases: []string{"visitor_request_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_BURST"}, Value: server.DefaultVisitorRequestLimitBurst, Usage: "initial limit of requests per visitor"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-request-limit-replenish", Aliases: []string{"visitor_request_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_REPLENISH"}, Value: server.DefaultVisitorRequestLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "visitor-request-limit-exempt-hosts", Aliases: []string{"visitor_request_limit_exempt_hosts"}, EnvVars: []string{"NTFY_VISITOR_REQUEST_LIMIT_EXEMPT_HOSTS"}, Value: "", Usage: "hostnames and/or IP addresses of hosts that will be exempt from the visitor request limit"}),
+	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-message-daily-limit", Aliases: []string{"visitor_message_daily_limit"}, EnvVars: []string{"NTFY_VISITOR_MESSAGE_DAILY_LIMIT"}, Value: server.DefaultVisitorMessageDailyLimit, Usage: "max messages per visitor per day, derived from request limit if unset"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
@@ -150,6 +151,7 @@ func execServe(c *cli.Context) error {
 	visitorRequestLimitBurst := c.Int("visitor-request-limit-burst")
 	visitorRequestLimitReplenish := c.Duration("visitor-request-limit-replenish")
 	visitorRequestLimitExemptHosts := util.SplitNoEmpty(c.String("visitor-request-limit-exempt-hosts"), ",")
+	visitorMessageDailyLimit := c.Int("visitor-message-daily-limit")
 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
 	visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
 	behindProxy := c.Bool("behind-proxy")
@@ -289,6 +291,7 @@ func execServe(c *cli.Context) error {
 	conf.VisitorRequestLimitBurst = visitorRequestLimitBurst
 	conf.VisitorRequestLimitReplenish = visitorRequestLimitReplenish
 	conf.VisitorRequestExemptIPAddrs = visitorRequestLimitExemptIPs
+	conf.VisitorMessageDailyLimit = visitorMessageDailyLimit
 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
 	conf.BehindProxy = behindProxy

+ 3 - 0
server/config.go

@@ -44,6 +44,7 @@ const (
 	DefaultVisitorSubscriptionLimit             = 30
 	DefaultVisitorRequestLimitBurst             = 60
 	DefaultVisitorRequestLimitReplenish         = 5 * time.Second
+	DefaultVisitorMessageDailyLimit             = 0
 	DefaultVisitorEmailLimitBurst               = 16
 	DefaultVisitorEmailLimitReplenish           = time.Hour
 	DefaultVisitorAccountCreationLimitBurst     = 3
@@ -105,6 +106,7 @@ type Config struct {
 	VisitorRequestLimitBurst             int
 	VisitorRequestLimitReplenish         time.Duration
 	VisitorRequestExemptIPAddrs          []netip.Prefix
+	VisitorMessageDailyLimit             int
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitReplenish           time.Duration
 	VisitorAccountCreationLimitBurst     int
@@ -171,6 +173,7 @@ func NewConfig() *Config {
 		VisitorRequestLimitBurst:             DefaultVisitorRequestLimitBurst,
 		VisitorRequestLimitReplenish:         DefaultVisitorRequestLimitReplenish,
 		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0),
+		VisitorMessageDailyLimit:             DefaultVisitorMessageDailyLimit,
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
 		VisitorAccountCreationLimitBurst:     DefaultVisitorAccountCreationLimitBurst,

+ 2 - 2
server/errors.go

@@ -75,10 +75,10 @@ var (
 	errHTTPTooManyRequestsLimitEmails                = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsLimitSubscriptions         = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsLimitTotalTopics           = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
-	errHTTPTooManyRequestsLimitAttachmentBandwidth   = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitAttachmentBandwidth   = &errHTTP{42905, http.StatusTooManyRequests, "limit reached: daily bandwidth reached", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPTooManyRequestsLimitAccountCreation       = &errHTTP{42906, http.StatusTooManyRequests, "limit reached: too many accounts created", "https://ntfy.sh/docs/publish/#limitations"} // FIXME document limit
 	errHTTPTooManyRequestsLimitReservations          = &errHTTP{42907, http.StatusTooManyRequests, "limit reached: too many topic reservations for this user", ""}
-	errHTTPTooManyRequestsLimitMessages              = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: too many messages", "https://ntfy.sh/docs/publish/#limitations"}
+	errHTTPTooManyRequestsLimitMessages              = &errHTTP{42908, http.StatusTooManyRequests, "limit reached: daily message quota reached", "https://ntfy.sh/docs/publish/#limitations"}
 	errHTTPInternalError                             = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
 	errHTTPInternalErrorInvalidPath                  = &errHTTP{50002, http.StatusInternalServerError, "internal server error: invalid path", ""}
 	errHTTPInternalErrorMissingBaseURL               = &errHTTP{50003, http.StatusInternalServerError, "internal server error: base-url must be be configured for this feature", "https://ntfy.sh/docs/config/"}

+ 18 - 31
server/server.go

@@ -38,10 +38,9 @@ import (
 TODO
 --
 
-- HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
-- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
 - HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
+- MEDIUM Rate limiting: Test daily message quota read from database initially
 - MEDIUM: Races with v.user (see publishSyncEventAsync test)
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
 - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@@ -57,7 +56,6 @@ Make sure account endpoints make sense for admins
 
 Tests:
 - Payment endpoints (make mocks)
-- test that the visitor is based on the IP address when a user has no tier
 */
 
 // Server is the main server, providing the UI and API for ntfy
@@ -308,7 +306,7 @@ func (s *Server) Stop() {
 }
 
 func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
-	v, err := s.visitor(r) // Note: Always returns v, even when error is returned
+	v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
 	if err == nil {
 		log.Debug("%s Dispatching request", logHTTPPrefix(v, r))
 		if log.IsTrace() {
@@ -563,7 +561,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if v.user != nil {
 		m.User = v.user.ID
 	}
-	m.Expires = time.Now().Add(v.Limits().MessagesExpiryDuration).Unix()
+	m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
 	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
 		return nil, err
 	}
@@ -601,7 +599,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	}
 	v.IncrementMessages()
 	if s.userManager != nil && v.user != nil {
-		s.userManager.EnqueueStats(v.user)
+		s.userManager.EnqueueStats(v.user) // FIXME this makes no sense for tier-less users
 	}
 	s.mu.Lock()
 	s.messages++
@@ -1382,8 +1380,10 @@ func (s *Server) runStatsResetter() {
 		log.Debug("Stats resetter: Waiting until %v to reset visitor stats", runAt)
 		select {
 		case <-timer.C:
+			log.Debug("Stats resetter: Running")
 			s.resetStats()
 		case <-s.closeChan:
+			log.Debug("Stats resetter: Stopping timer")
 			timer.Stop()
 			return
 		}
@@ -1440,17 +1440,15 @@ func (s *Server) sendDelayedMessages() error {
 		return err
 	}
 	for _, m := range messages {
-		var v *visitor
+		var u *user.User
 		if s.userManager != nil && m.User != "" {
-			u, err := s.userManager.User(m.User)
+			u, err = s.userManager.User(m.User)
 			if err != nil {
-				log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
+				log.Warn("Error sending delayed message %s: %s", m.ID, err.Error())
 				continue
 			}
-			v = s.visitorFromUser(u, m.Sender)
-		} else {
-			v = s.visitorFromIP(m.Sender)
 		}
+		v := s.visitor(m.Sender, u)
 		if err := s.sendDelayedMessage(v, m); err != nil {
 			log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
 		}
@@ -1588,20 +1586,16 @@ func (s *Server) autorizeTopic(next handleFunc, perm user.Permission) handleFunc
 	}
 }
 
-// visitor creates or retrieves a rate.Limiter for the given visitor.
+// maybeAuthenticate creates or retrieves a rate.Limiter for the given visitor.
 // Note that this function will always return a visitor, even if an error occurs.
-func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
+func (s *Server) maybeAuthenticate(r *http.Request) (v *visitor, err error) {
 	ip := extractIPAddress(r, s.config.BehindProxy)
 	var u *user.User // may stay nil if no auth header!
 	if u, err = s.authenticate(r); err != nil {
 		log.Debug("authentication failed: %s", err.Error())
 		err = errHTTPUnauthorized // Always return visitor, even when error occurs!
 	}
-	if u != nil {
-		v = s.visitorFromUser(u, ip)
-	} else {
-		v = s.visitorFromIP(ip)
-	}
+	v = s.visitor(ip, u)
 	v.SetUser(u)  // Update visitor user with latest from database!
 	return v, err // Always return visitor, even when error occurs!
 }
@@ -1645,26 +1639,19 @@ func (s *Server) authenticateBearerAuth(value string) (user *user.User, err erro
 	return s.userManager.AuthenticateToken(token)
 }
 
-func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *user.User) *visitor {
+func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	v, exists := s.visitors[visitorID]
+	id := visitorID(ip, user)
+	v, exists := s.visitors[id]
 	if !exists {
-		s.visitors[visitorID] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
-		return s.visitors[visitorID]
+		s.visitors[id] = newVisitor(s.config, s.messageCache, s.userManager, ip, user)
+		return s.visitors[id]
 	}
 	v.Keepalive()
 	return v
 }
 
-func (s *Server) visitorFromIP(ip netip.Addr) *visitor {
-	return s.visitorFromID(fmt.Sprintf("ip:%s", ip.String()), ip, nil)
-}
-
-func (s *Server) visitorFromUser(user *user.User, ip netip.Addr) *visitor {
-	return s.visitorFromID(fmt.Sprintf("user:%s", user.ID), ip, user)
-}
-
 func (s *Server) writeJSON(w http.ResponseWriter, v any) error {
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests

+ 6 - 0
server/server.yml

@@ -200,6 +200,12 @@
 # visitor-request-limit-replenish: "5s"
 # visitor-request-limit-exempt-hosts: ""
 
+# Rate limiting: Hard daily limit of messages per visitor and day. The limit is reset
+# every day at midnight UTC. If the limit is not set (or set to zero), the request
+# limit (see above) governs the upper limit.
+#
+# visitor-message-daily-limit: 0
+
 # Rate limiting: Allowed emails per visitor:
 # - visitor-email-limit-burst is the initial bucket of emails each visitor has
 # - visitor-email-limit-replenish is the rate at which the bucket is refilled

+ 7 - 7
server/server_account.go

@@ -23,6 +23,9 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 		} else if v.user != nil {
 			return errHTTPUnauthorized // Cannot create account from user context
 		}
+		if err := v.AccountCreationAllowed(); err != nil {
+			return errHTTPTooManyRequestsLimitAccountCreation
+		}
 	}
 	newAccount, err := readJSONWithLimit[apiAccountCreateRequest](r.Body, jsonBodyBytesLimit)
 	if err != nil {
@@ -31,9 +34,6 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 	if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
 		return errHTTPConflictUserExists
 	}
-	if err := v.AccountCreationAllowed(); err != nil {
-		return errHTTPTooManyRequestsLimitAccountCreation
-	}
 	if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
 		return err
 	}
@@ -49,9 +49,9 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
 	response := &apiAccountResponse{
 		Limits: &apiAccountLimits{
 			Basis:                    string(limits.Basis),
-			Messages:                 limits.MessagesLimit,
-			MessagesExpiryDuration:   int64(limits.MessagesExpiryDuration.Seconds()),
-			Emails:                   limits.EmailsLimit,
+			Messages:                 limits.MessageLimit,
+			MessagesExpiryDuration:   int64(limits.MessageExpiryDuration.Seconds()),
+			Emails:                   limits.EmailLimit,
 			Reservations:             limits.ReservationsLimit,
 			AttachmentTotalSize:      limits.AttachmentTotalSizeLimit,
 			AttachmentFileSize:       limits.AttachmentFileSizeLimit,
@@ -344,7 +344,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
 		reservations, err := s.userManager.ReservationsCount(v.user.Name)
 		if err != nil {
 			return err
-		} else if reservations >= v.user.Tier.ReservationsLimit {
+		} else if reservations >= v.user.Tier.ReservationLimit {
 			return errHTTPTooManyRequestsLimitReservations
 		}
 	}

+ 18 - 18
server/server_account_test.go

@@ -410,10 +410,10 @@ func TestAccount_Reservation_AddRemoveUserWithTierSuccess(t *testing.T) {
 	// Create a tier
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 		Code:                     "pro",
-		MessagesLimit:            123,
-		MessagesExpiryDuration:   86400 * time.Second,
-		EmailsLimit:              32,
-		ReservationsLimit:        2,
+		MessageLimit:             123,
+		MessageExpiryDuration:    86400 * time.Second,
+		EmailLimit:               32,
+		ReservationLimit:         2,
 		AttachmentFileSizeLimit:  1231231,
 		AttachmentTotalSizeLimit: 123123,
 		AttachmentExpiryDuration: 10800 * time.Second,
@@ -491,9 +491,9 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
 	require.Equal(t, 200, rr.Code)
 
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
-		Code:              "pro",
-		MessagesLimit:     20,
-		ReservationsLimit: 2,
+		Code:             "pro",
+		MessageLimit:     20,
+		ReservationLimit: 2,
 	}))
 	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
 
@@ -525,9 +525,9 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
 	require.Equal(t, 200, rr.Code)
 
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
-		Code:              "pro",
-		MessagesLimit:     20,
-		ReservationsLimit: 2,
+		Code:             "pro",
+		MessageLimit:     20,
+		ReservationLimit: 2,
 	}))
 	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
 
@@ -591,10 +591,10 @@ func TestAccount_Tier_Create(t *testing.T) {
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 		Code:                     "pro",
 		Name:                     "Pro",
-		MessagesLimit:            123,
-		MessagesExpiryDuration:   86400 * time.Second,
-		EmailsLimit:              32,
-		ReservationsLimit:        2,
+		MessageLimit:             123,
+		MessageExpiryDuration:    86400 * time.Second,
+		EmailLimit:               32,
+		ReservationLimit:         2,
 		AttachmentFileSizeLimit:  1231231,
 		AttachmentTotalSizeLimit: 123123,
 		AttachmentExpiryDuration: 10800 * time.Second,
@@ -616,10 +616,10 @@ func TestAccount_Tier_Create(t *testing.T) {
 	require.True(t, strings.HasPrefix(ti.ID, "ti_"))
 	require.Equal(t, "pro", ti.Code)
 	require.Equal(t, "Pro", ti.Name)
-	require.Equal(t, int64(123), ti.MessagesLimit)
-	require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
-	require.Equal(t, int64(32), ti.EmailsLimit)
-	require.Equal(t, int64(2), ti.ReservationsLimit)
+	require.Equal(t, int64(123), ti.MessageLimit)
+	require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
+	require.Equal(t, int64(32), ti.EmailLimit)
+	require.Equal(t, int64(2), ti.ReservationLimit)
 	require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
 	require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
 	require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)

+ 11 - 11
server/server_payments.go

@@ -60,15 +60,15 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
 	if err != nil {
 		return err
 	}
-	freeTier := defaultVisitorLimits(s.config)
+	freeTier := configBasedVisitorLimits(s.config)
 	response := []*apiAccountBillingTier{
 		{
 			// This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
 			Limits: &apiAccountLimits{
 				Basis:                    string(visitorLimitBasisIP),
-				Messages:                 freeTier.MessagesLimit,
-				MessagesExpiryDuration:   int64(freeTier.MessagesExpiryDuration.Seconds()),
-				Emails:                   freeTier.EmailsLimit,
+				Messages:                 freeTier.MessageLimit,
+				MessagesExpiryDuration:   int64(freeTier.MessageExpiryDuration.Seconds()),
+				Emails:                   freeTier.EmailLimit,
 				Reservations:             freeTier.ReservationsLimit,
 				AttachmentTotalSize:      freeTier.AttachmentTotalSizeLimit,
 				AttachmentFileSize:       freeTier.AttachmentFileSizeLimit,
@@ -91,10 +91,10 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
 			Price: priceStr,
 			Limits: &apiAccountLimits{
 				Basis:                    string(visitorLimitBasisTier),
-				Messages:                 tier.MessagesLimit,
-				MessagesExpiryDuration:   int64(tier.MessagesExpiryDuration.Seconds()),
-				Emails:                   tier.EmailsLimit,
-				Reservations:             tier.ReservationsLimit,
+				Messages:                 tier.MessageLimit,
+				MessagesExpiryDuration:   int64(tier.MessageExpiryDuration.Seconds()),
+				Emails:                   tier.EmailLimit,
+				Reservations:             tier.ReservationLimit,
 				AttachmentTotalSize:      tier.AttachmentTotalSizeLimit,
 				AttachmentFileSize:       tier.AttachmentFileSizeLimit,
 				AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
@@ -336,7 +336,7 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMe
 	if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
 		return err
 	}
-	s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
+	s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
 	return nil
 }
 
@@ -355,14 +355,14 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMe
 	if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
 		return err
 	}
-	s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
+	s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
 	return nil
 }
 
 func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
 	reservationsLimit := visitorDefaultReservationsLimit
 	if tier != nil {
-		reservationsLimit = tier.ReservationsLimit
+		reservationsLimit = tier.ReservationLimit
 	}
 	if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
 		return err

+ 73 - 28
server/server_payments_test.go

@@ -5,11 +5,14 @@ import (
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
 	"github.com/stripe/stripe-go/v74"
+	"golang.org/x/time/rate"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
 	"io"
+	"net/netip"
 	"path/filepath"
 	"strings"
+	"sync"
 	"testing"
 	"time"
 )
@@ -48,10 +51,10 @@ func TestPayments_Tiers(t *testing.T) {
 		ID:                       "ti_123",
 		Code:                     "pro",
 		Name:                     "Pro",
-		MessagesLimit:            1000,
-		MessagesExpiryDuration:   time.Hour,
-		EmailsLimit:              123,
-		ReservationsLimit:        777,
+		MessageLimit:             1000,
+		MessageExpiryDuration:    time.Hour,
+		EmailLimit:               123,
+		ReservationLimit:         777,
 		AttachmentFileSizeLimit:  999,
 		AttachmentTotalSizeLimit: 888,
 		AttachmentExpiryDuration: time.Minute,
@@ -61,10 +64,10 @@ func TestPayments_Tiers(t *testing.T) {
 		ID:                       "ti_444",
 		Code:                     "business",
 		Name:                     "Business",
-		MessagesLimit:            2000,
-		MessagesExpiryDuration:   10 * time.Hour,
-		EmailsLimit:              123123,
-		ReservationsLimit:        777333,
+		MessageLimit:             2000,
+		MessageExpiryDuration:    10 * time.Hour,
+		EmailLimit:               123123,
+		ReservationLimit:         777333,
 		AttachmentFileSizeLimit:  999111,
 		AttachmentTotalSizeLimit: 888111,
 		AttachmentExpiryDuration: time.Hour,
@@ -238,9 +241,14 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
 	require.Equal(t, 401, rr.Code)
 }
 
-func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
-	// This tests a successful checkout flow (not a paying customer -> paying customer),
-	// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
+func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
+	// This test is too overloaded, but it's also a great end-to-end a test.
+	//
+	// It tests:
+	// - A successful checkout flow (not a paying customer -> paying customer)
+	// - Tier-changes reset the rate limits for the user
+	// - The request limits for tier-less user and a tier-user
+	// - The message limits for a tier-user
 
 	stripeMock := &testStripeAPI{}
 	defer stripeMock.AssertExpectations(t)
@@ -248,19 +256,26 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
 	c := newTestConfigWithAuthFile(t)
 	c.StripeSecretKey = "secret key"
 	c.StripeWebhookKey = "webhook key"
-	c.VisitorRequestLimitBurst = 10
+	c.VisitorRequestLimitBurst = 5
 	c.VisitorRequestLimitReplenish = time.Hour
+	c.CacheStartupQueries = `
+pragma journal_mode = WAL;
+pragma synchronous = normal;
+pragma temp_store = memory;
+`
+	c.CacheBatchSize = 500
+	c.CacheBatchTimeout = time.Second
 	s := newTestServer(t, c)
 	s.stripe = stripeMock
 
 	// Create a user with a Stripe subscription and 3 reservations
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
-		ID:                     "ti_123",
-		Code:                   "starter",
-		StripePriceID:          "price_1234",
-		ReservationsLimit:      1,
-		MessagesLimit:          100,
-		MessagesExpiryDuration: time.Hour,
+		ID:                    "ti_123",
+		Code:                  "starter",
+		StripePriceID:         "price_1234",
+		ReservationLimit:      1,
+		MessageLimit:          220, // 220 * 5% = 11 requests before rate limiting kicks in
+		MessageExpiryDuration: time.Hour,
 	}))
 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
 	u, err := s.userManager.User("phil")
@@ -298,7 +313,7 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
 		Return(&stripe.Customer{}, nil)
 
 	// Send messages until rate limit of free tier is hit
-	for i := 0; i < 10; i++ {
+	for i := 0; i < 5; i++ {
 		rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
 			"Authorization": util.BasicAuth("phil", "phil"),
 		})
@@ -323,10 +338,9 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
 	require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
 	require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
 
-	// FIXME FIXME This test is broken, because the rate limit logic is unclear!
-
 	// Now for the fun part: Verify that new rate limits are immediately applied
-	for i := 0; i < 100; i++ {
+	// This only tests the request limiter, which kicks in before the message limiter.
+	for i := 0; i < 11; i++ {
 		rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
 			"Authorization": util.BasicAuth("phil", "phil"),
 		})
@@ -336,6 +350,37 @@ func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *test
 		"Authorization": util.BasicAuth("phil", "phil"),
 	})
 	require.Equal(t, 429, rr.Code)
+
+	// Now let's test the message limiter by faking a ridiculously generous rate limiter
+	v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
+	v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
+
+	var wg sync.WaitGroup
+	for i := 0; i < 209; i++ {
+		wg.Add(1)
+		go func() {
+			rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+				"Authorization": util.BasicAuth("phil", "phil"),
+			})
+			require.Equal(t, 200, rr.Code)
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 429, rr.Code)
+
+	// And now let's cross-check that the stats are correct too
+	rr = request(t, s, "GET", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, rr.Code)
+	account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
+	require.Equal(t, int64(220), account.Limits.Messages)
+	require.Equal(t, int64(220), account.Stats.Messages)
+	require.Equal(t, int64(0), account.Stats.MessagesRemaining)
 }
 
 func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
@@ -363,9 +408,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 		ID:                       "ti_1",
 		Code:                     "starter",
 		StripePriceID:            "price_1234", // !
-		ReservationsLimit:        1,            // !
-		MessagesLimit:            100,
-		MessagesExpiryDuration:   time.Hour,
+		ReservationLimit:         1,            // !
+		MessageLimit:             100,
+		MessageExpiryDuration:    time.Hour,
 		AttachmentExpiryDuration: time.Hour,
 		AttachmentFileSizeLimit:  1000000,
 		AttachmentTotalSizeLimit: 1000000,
@@ -375,9 +420,9 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 		ID:                       "ti_2",
 		Code:                     "pro",
 		StripePriceID:            "price_1111", // !
-		ReservationsLimit:        3,            // !
-		MessagesLimit:            200,
-		MessagesExpiryDuration:   time.Hour,
+		ReservationLimit:         3,            // !
+		MessageLimit:             200,
+		MessageExpiryDuration:    time.Hour,
 		AttachmentExpiryDuration: time.Hour,
 		AttachmentFileSizeLimit:  1000000,
 		AttachmentTotalSizeLimit: 1000000,

+ 82 - 27
server/server_test.go

@@ -8,7 +8,6 @@ import (
 	"fmt"
 	"heckel.io/ntfy/user"
 	"io"
-	"log"
 	"math/rand"
 	"net/http"
 	"net/http/httptest"
@@ -22,9 +21,14 @@ import (
 	"github.com/stretchr/testify/assert"
 
 	"github.com/stretchr/testify/require"
+	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/util"
 )
 
+func init() {
+	// log.SetLevel(log.DebugLevel)
+}
+
 func TestServer_PublishAndPoll(t *testing.T) {
 	s := newTestServer(t, newTestConfig(t))
 
@@ -742,16 +746,31 @@ func TestServer_Auth_ViaQuery(t *testing.T) {
 	require.Equal(t, 401, response.Code)
 }
 
-func TestServer_StatsResetter(t *testing.T) {
+func TestServer_StatsResetter_User_Without_Tier(t *testing.T) {
+	// This tests the stats resetter for
+	// - an anonymous user
+	// - a user without a tier (treated like the same as the anonymous user)
+	// - a user with a tier
+
 	c := newTestConfigWithAuthFile(t)
-	c.AuthDefault = user.PermissionDenyAll
 	c.VisitorStatsResetTime = time.Now().Add(2 * time.Second)
 	s := newTestServer(t, c)
 	go s.runStatsResetter()
 
+	// Create user with tier (tieruser) and user without tier (phil)
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:                  "test",
+		MessageLimit:          5,
+		MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
+	}))
 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
-	require.Nil(t, s.userManager.AllowAccess("phil", "mytopic", user.PermissionReadWrite))
+	require.Nil(t, s.userManager.AddUser("tieruser", "tieruser", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("tieruser", "test"))
+
+	// Send an anonymous message
+	response := request(t, s, "PUT", "/mytopic", "test", nil)
 
+	// Send messages from user without tier (phil)
 	for i := 0; i < 5; i++ {
 		response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
 			"Authorization": util.BasicAuth("phil", "phil"),
@@ -759,30 +778,66 @@ func TestServer_StatsResetter(t *testing.T) {
 		require.Equal(t, 200, response.Code)
 	}
 
-	response := request(t, s, "GET", "/v1/account", "", map[string]string{
+	// Send messages from user with tier
+	for i := 0; i < 2; i++ {
+		response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
+			"Authorization": util.BasicAuth("tieruser", "tieruser"),
+		})
+		require.Equal(t, 200, response.Code)
+	}
+
+	// User stats show 6 messages (for user without tier)
+	response = request(t, s, "GET", "/v1/account", "", map[string]string{
 		"Authorization": util.BasicAuth("phil", "phil"),
 	})
 	require.Equal(t, 200, response.Code)
+	account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
+	require.Nil(t, err)
+	require.Equal(t, int64(6), account.Stats.Messages)
+
+	// User stats show 6 messages (for anonymous visitor)
+	response = request(t, s, "GET", "/v1/account", "", nil)
+	require.Equal(t, 200, response.Code)
+	account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
+	require.Nil(t, err)
+	require.Equal(t, int64(6), account.Stats.Messages)
 
-	// User stats show 10 messages
+	// User stats show 2 messages (for user with tier)
 	response = request(t, s, "GET", "/v1/account", "", map[string]string{
-		"Authorization": util.BasicAuth("phil", "phil"),
+		"Authorization": util.BasicAuth("tieruser", "tieruser"),
 	})
 	require.Equal(t, 200, response.Code)
-	account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
+	account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
 	require.Nil(t, err)
-	require.Equal(t, int64(5), account.Stats.Messages)
+	require.Equal(t, int64(2), account.Stats.Messages)
 
 	// Wait for stats resetter to run
 	time.Sleep(2200 * time.Millisecond)
 
 	// User stats show 0 messages now!
+	response = request(t, s, "GET", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, response.Code)
+	account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
+	require.Nil(t, err)
+	require.Equal(t, int64(0), account.Stats.Messages)
+
+	// Since this is a user without a tier, the anonymous user should have the same stats
 	response = request(t, s, "GET", "/v1/account", "", nil)
 	require.Equal(t, 200, response.Code)
 	account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
 	require.Nil(t, err)
 	require.Equal(t, int64(0), account.Stats.Messages)
 
+	// User stats show 0 messages (for user with tier)
+	response = request(t, s, "GET", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("tieruser", "tieruser"),
+	})
+	require.Equal(t, 200, response.Code)
+	account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(response.Body))
+	require.Nil(t, err)
+	require.Equal(t, int64(0), account.Stats.Messages)
 }
 
 type testMailer struct {
@@ -1133,9 +1188,9 @@ func TestServer_PublishWithTierBasedMessageLimitAndExpiry(t *testing.T) {
 
 	// Create tier with certain limits
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
-		Code:                   "test",
-		MessagesLimit:          5,
-		MessagesExpiryDuration: -5 * time.Second, // Second, what a hack!
+		Code:                  "test",
+		MessageLimit:          5,
+		MessageExpiryDuration: -5 * time.Second, // Second, what a hack!
 	}))
 	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
 	require.Nil(t, s.userManager.ChangeTier("phil", "test"))
@@ -1363,8 +1418,8 @@ func TestServer_PublishAttachmentWithTierBasedExpiry(t *testing.T) {
 	sevenDays := time.Duration(604800) * time.Second
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 		Code:                     "test",
-		MessagesLimit:            10,
-		MessagesExpiryDuration:   sevenDays,
+		MessageLimit:             10,
+		MessageExpiryDuration:    sevenDays,
 		AttachmentFileSizeLimit:  50_000,
 		AttachmentTotalSizeLimit: 200_000,
 		AttachmentExpiryDuration: sevenDays, // 7 days
@@ -1407,8 +1462,8 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) {
 	// Create tier with certain limits
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 		Code:                     "test",
-		MessagesLimit:            10,
-		MessagesExpiryDuration:   time.Hour,
+		MessageLimit:             10,
+		MessageExpiryDuration:    time.Hour,
 		AttachmentFileSizeLimit:  50_000,
 		AttachmentTotalSizeLimit: 200_000,
 		AttachmentExpiryDuration: time.Hour,
@@ -1450,7 +1505,7 @@ func TestServer_PublishAttachmentWithTierBasedLimits(t *testing.T) {
 	// Create tier with certain limits
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 		Code:                     "test",
-		MessagesLimit:            100,
+		MessageLimit:             100,
 		AttachmentFileSizeLimit:  50_000,
 		AttachmentTotalSizeLimit: 200_000,
 		AttachmentExpiryDuration: 30 * time.Second,
@@ -1574,7 +1629,7 @@ func TestServer_Visitor_XForwardedFor_None(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/bla", nil)
 	r.RemoteAddr = "8.9.10.11"
 	r.Header.Set("X-Forwarded-For", "  ") // Spaces, not empty!
-	v, err := s.visitor(r)
+	v, err := s.maybeAuthenticate(r)
 	require.Nil(t, err)
 	require.Equal(t, "8.9.10.11", v.ip.String())
 }
@@ -1586,7 +1641,7 @@ func TestServer_Visitor_XForwardedFor_Single(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/bla", nil)
 	r.RemoteAddr = "8.9.10.11"
 	r.Header.Set("X-Forwarded-For", "1.1.1.1")
-	v, err := s.visitor(r)
+	v, err := s.maybeAuthenticate(r)
 	require.Nil(t, err)
 	require.Equal(t, "1.1.1.1", v.ip.String())
 }
@@ -1598,7 +1653,7 @@ func TestServer_Visitor_XForwardedFor_Multiple(t *testing.T) {
 	r, _ := http.NewRequest("GET", "/bla", nil)
 	r.RemoteAddr = "8.9.10.11"
 	r.Header.Set("X-Forwarded-For", "1.2.3.4 , 2.4.4.2,234.5.2.1 ")
-	v, err := s.visitor(r)
+	v, err := s.maybeAuthenticate(r)
 	require.Nil(t, err)
 	require.Equal(t, "234.5.2.1", v.ip.String())
 }
@@ -1611,7 +1666,7 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
 	s := newTestServer(t, c)
 
 	// Add lots of messages
-	log.Printf("Adding %d messages", count)
+	log.Info("Adding %d messages", count)
 	start := time.Now()
 	messages := make([]*message, 0)
 	for i := 0; i < count; i++ {
@@ -1621,31 +1676,31 @@ func TestServer_PublishWhileUpdatingStatsWithLotsOfMessages(t *testing.T) {
 		messages = append(messages, newDefaultMessage(topicID, "some message"))
 	}
 	require.Nil(t, s.messageCache.addMessages(messages))
-	log.Printf("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
+	log.Info("Done: Adding %d messages; took %s", count, time.Since(start).Round(time.Millisecond))
 
 	// Update stats
 	statsChan := make(chan bool)
 	go func() {
-		log.Printf("Updating stats")
+		log.Info("Updating stats")
 		start := time.Now()
 		s.execManager()
-		log.Printf("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
+		log.Info("Done: Updating stats; took %s", time.Since(start).Round(time.Millisecond))
 		statsChan <- true
 	}()
 	time.Sleep(50 * time.Millisecond) // Make sure it starts first
 
 	// Publish message (during stats update)
-	log.Printf("Publishing message")
+	log.Info("Publishing message")
 	start = time.Now()
 	response := request(t, s, "PUT", "/mytopic", "some body", nil)
 	m := toMessage(t, response.Body.String())
 	assert.Equal(t, "some body", m.Message)
 	assert.True(t, time.Since(start) < 100*time.Millisecond)
-	log.Printf("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
+	log.Info("Done: Publishing message; took %s", time.Since(start).Round(time.Millisecond))
 
 	// Wait for all goroutines
 	<-statsChan
-	log.Printf("Done: Waiting for all locks")
+	log.Info("Done: Waiting for all locks")
 }
 
 func newTestConfig(t *testing.T) *Config {

+ 100 - 56
server/visitor.go

@@ -14,16 +14,39 @@ import (
 )
 
 const (
+	// oneDay is an approximation of a day as a time.Duration
+	oneDay = 24 * time.Hour
+
 	// visitorExpungeAfter defines how long a visitor is active before it is removed from memory. This number
 	// has to be very high to prevent e-mail abuse, but it doesn't really affect the other limits anyway, since
 	// they are replenished faster (typically).
-	visitorExpungeAfter = 24 * time.Hour
+	visitorExpungeAfter = oneDay
 
 	// visitorDefaultReservationsLimit is the amount of topic names a user without a tier is allowed to reserve.
 	// This number is zero, and changing it may have unintended consequences in the web app, or otherwise
 	visitorDefaultReservationsLimit = int64(0)
 )
 
+// Constants used to convert a tier-user's MessageLimit (see user.Tier) into adequate request limiter
+// values (token bucket).
+//
+// Example: Assuming a user.Tier's MessageLimit is 10,000:
+// - the allowed burst is 500 (= 10,000 * 5%), which is < 1000 (the max)
+// - the replenish rate is 2 * 10,000 / 24 hours
+const (
+	visitorMessageToRequestLimitBurstRate       = 0.05
+	visitorMessageToRequestLimitBurstMax        = 1000
+	visitorMessageToRequestLimitReplenishFactor = 2
+)
+
+// Constants used to convert a tier-user's EmailLimit (see user.Tier) into adequate email limiter
+// values (token bucket). Example: Assuming a user.Tier's EmailLimit is 200, the allowed burst is
+// 40 (= 200 * 20%), which is <150 (the max).
+const (
+	visitorEmailLimitBurstRate = 0.2
+	visitorEmailLimitBurstMax  = 150
+)
+
 var (
 	errVisitorLimitReached = errors.New("limit reached")
 )
@@ -55,9 +78,13 @@ type visitorInfo struct {
 
 type visitorLimits struct {
 	Basis                    visitorLimitBasis
-	MessagesLimit            int64
-	MessagesExpiryDuration   time.Duration
-	EmailsLimit              int64
+	RequestLimitBurst        int
+	RequestLimitReplenish    rate.Limit
+	MessageLimit             int64
+	MessageExpiryDuration    time.Duration
+	EmailLimit               int64
+	EmailLimitBurst          int
+	EmailLimitReplenish      rate.Limit
 	ReservationsLimit        int64
 	AttachmentTotalSizeLimit int64
 	AttachmentFileSizeLimit  int64
@@ -173,7 +200,7 @@ func (v *visitor) SubscriptionAllowed() error {
 }
 
 func (v *visitor) AccountCreationAllowed() error {
-	if v.accountLimiter != nil && !v.accountLimiter.Allow() {
+	if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
 		return errVisitorLimitReached
 	}
 	return nil
@@ -242,31 +269,6 @@ func (v *visitor) SetUser(u *user.User) {
 	}
 }
 
-func (v *visitor) resetLimiters() {
-	log.Info("%s Resetting limiters for visitor", v.stringNoLock())
-	var messagesLimiter, bandwidthLimiter util.Limiter
-	var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
-	if v.user != nil && v.user.Tier != nil {
-		requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
-		messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
-		emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
-		bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
-	} else {
-		requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
-		messagesLimiter = nil // Message limit is governed by the requestLimiter
-		emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
-		bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
-	}
-	if v.user == nil {
-		accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
-	}
-	v.requestLimiter = requestLimiter
-	v.messagesLimiter = messagesLimiter
-	v.emailsLimiter = emailsLimiter
-	v.bandwidthLimiter = bandwidthLimiter
-	v.accountLimiter = accountLimiter
-}
-
 // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
 // an empty string is returned.
 func (v *visitor) MaybeUserID() string {
@@ -278,22 +280,71 @@ func (v *visitor) MaybeUserID() string {
 	return ""
 }
 
+func (v *visitor) resetLimiters() {
+	log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
+	limits := v.limitsNoLock()
+	v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
+	v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, v.messages)
+	v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst)
+	v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
+	if v.user == nil {
+		v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
+	} else {
+		v.accountLimiter = nil // Users cannot create accounts when logged in
+	}
+}
+
 func (v *visitor) Limits() *visitorLimits {
 	v.mu.Lock()
 	defer v.mu.Unlock()
-	limits := defaultVisitorLimits(v.config)
+	return v.limitsNoLock()
+}
+
+func (v *visitor) limitsNoLock() *visitorLimits {
 	if v.user != nil && v.user.Tier != nil {
-		limits.Basis = visitorLimitBasisTier
-		limits.MessagesLimit = v.user.Tier.MessagesLimit
-		limits.MessagesExpiryDuration = v.user.Tier.MessagesExpiryDuration
-		limits.EmailsLimit = v.user.Tier.EmailsLimit
-		limits.ReservationsLimit = v.user.Tier.ReservationsLimit
-		limits.AttachmentTotalSizeLimit = v.user.Tier.AttachmentTotalSizeLimit
-		limits.AttachmentFileSizeLimit = v.user.Tier.AttachmentFileSizeLimit
-		limits.AttachmentExpiryDuration = v.user.Tier.AttachmentExpiryDuration
-		limits.AttachmentBandwidthLimit = v.user.Tier.AttachmentBandwidthLimit
+		return tierBasedVisitorLimits(v.config, v.user.Tier)
+	}
+	return configBasedVisitorLimits(v.config)
+}
+
+func tierBasedVisitorLimits(conf *Config, tier *user.Tier) *visitorLimits {
+	return &visitorLimits{
+		Basis:                    visitorLimitBasisTier,
+		RequestLimitBurst:        util.MinMax(int(float64(tier.MessageLimit)*visitorMessageToRequestLimitBurstRate), conf.VisitorRequestLimitBurst, visitorMessageToRequestLimitBurstMax),
+		RequestLimitReplenish:    dailyLimitToRate(tier.MessageLimit * visitorMessageToRequestLimitReplenishFactor),
+		MessageLimit:             tier.MessageLimit,
+		MessageExpiryDuration:    tier.MessageExpiryDuration,
+		EmailLimit:               tier.EmailLimit,
+		EmailLimitBurst:          util.MinMax(int(float64(tier.EmailLimit)*visitorEmailLimitBurstRate), conf.VisitorEmailLimitBurst, visitorEmailLimitBurstMax),
+		EmailLimitReplenish:      dailyLimitToRate(tier.EmailLimit),
+		ReservationsLimit:        tier.ReservationLimit,
+		AttachmentTotalSizeLimit: tier.AttachmentTotalSizeLimit,
+		AttachmentFileSizeLimit:  tier.AttachmentFileSizeLimit,
+		AttachmentExpiryDuration: tier.AttachmentExpiryDuration,
+		AttachmentBandwidthLimit: tier.AttachmentBandwidthLimit,
+	}
+}
+
+func configBasedVisitorLimits(conf *Config) *visitorLimits {
+	messagesLimit := replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish) // Approximation!
+	if conf.VisitorMessageDailyLimit > 0 {
+		messagesLimit = int64(conf.VisitorMessageDailyLimit)
+	}
+	return &visitorLimits{
+		Basis:                    visitorLimitBasisIP,
+		RequestLimitBurst:        conf.VisitorRequestLimitBurst,
+		RequestLimitReplenish:    rate.Every(conf.VisitorRequestLimitReplenish),
+		MessageLimit:             messagesLimit,
+		MessageExpiryDuration:    conf.CacheDuration,
+		EmailLimit:               replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish), // Approximation!
+		EmailLimitBurst:          conf.VisitorEmailLimitBurst,
+		EmailLimitReplenish:      rate.Every(conf.VisitorEmailLimitReplenish),
+		ReservationsLimit:        visitorDefaultReservationsLimit,
+		AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
+		AttachmentFileSizeLimit:  conf.AttachmentFileSizeLimit,
+		AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
+		AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
 	}
-	return limits
 }
 
 func (v *visitor) Info() (*visitorInfo, error) {
@@ -321,9 +372,9 @@ func (v *visitor) Info() (*visitorInfo, error) {
 	limits := v.Limits()
 	stats := &visitorStats{
 		Messages:                     messages,
-		MessagesRemaining:            zeroIfNegative(limits.MessagesLimit - messages),
+		MessagesRemaining:            zeroIfNegative(limits.MessageLimit - messages),
 		Emails:                       emails,
-		EmailsRemaining:              zeroIfNegative(limits.EmailsLimit - emails),
+		EmailsRemaining:              zeroIfNegative(limits.EmailLimit - emails),
 		Reservations:                 reservations,
 		ReservationsRemaining:        zeroIfNegative(limits.ReservationsLimit - reservations),
 		AttachmentTotalSize:          attachmentsBytesUsed,
@@ -343,23 +394,16 @@ func zeroIfNegative(value int64) int64 {
 }
 
 func replenishDurationToDailyLimit(duration time.Duration) int64 {
-	return int64(24 * time.Hour / duration)
+	return int64(oneDay / duration)
 }
 
 func dailyLimitToRate(limit int64) rate.Limit {
-	return rate.Limit(limit) * rate.Every(24*time.Hour)
+	return rate.Limit(limit) * rate.Every(oneDay)
 }
 
-func defaultVisitorLimits(conf *Config) *visitorLimits {
-	return &visitorLimits{
-		Basis:                    visitorLimitBasisIP,
-		MessagesLimit:            replenishDurationToDailyLimit(conf.VisitorRequestLimitReplenish),
-		MessagesExpiryDuration:   conf.CacheDuration,
-		EmailsLimit:              replenishDurationToDailyLimit(conf.VisitorEmailLimitReplenish),
-		ReservationsLimit:        visitorDefaultReservationsLimit,
-		AttachmentTotalSizeLimit: conf.VisitorAttachmentTotalSizeLimit,
-		AttachmentFileSizeLimit:  conf.AttachmentFileSizeLimit,
-		AttachmentExpiryDuration: conf.AttachmentExpiryDuration,
-		AttachmentBandwidthLimit: conf.VisitorAttachmentDailyBandwidthLimit,
+func visitorID(ip netip.Addr, u *user.User) string {
+	if u != nil && u.Tier != nil {
+		return fmt.Sprintf("user:%s", u.ID)
 	}
+	return fmt.Sprintf("ip:%s", ip.String())
 }

+ 11 - 11
user/manager.go

@@ -709,10 +709,10 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 			ID:                       tierID.String,
 			Code:                     tierCode.String,
 			Name:                     tierName.String,
-			MessagesLimit:            messagesLimit.Int64,
-			MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
-			EmailsLimit:              emailsLimit.Int64,
-			ReservationsLimit:        reservationsLimit.Int64,
+			MessageLimit:             messagesLimit.Int64,
+			MessageExpiryDuration:    time.Duration(messagesExpiryDuration.Int64) * time.Second,
+			EmailLimit:               emailsLimit.Int64,
+			ReservationLimit:         reservationsLimit.Int64,
 			AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
 			AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
 			AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
@@ -845,7 +845,7 @@ func (a *Manager) ChangeTier(username, tier string) error {
 	t, err := a.Tier(tier)
 	if err != nil {
 		return err
-	} else if err := a.checkReservationsLimit(username, t.ReservationsLimit); err != nil {
+	} else if err := a.checkReservationsLimit(username, t.ReservationLimit); err != nil {
 		return err
 	}
 	if _, err := a.db.Exec(updateUserTierQuery, tier, username); err != nil {
@@ -870,7 +870,7 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
 	if err != nil {
 		return err
 	}
-	if u.Tier != nil && reservationsLimit < u.Tier.ReservationsLimit {
+	if u.Tier != nil && reservationsLimit < u.Tier.ReservationLimit {
 		reservations, err := a.Reservations(username)
 		if err != nil {
 			return err
@@ -999,7 +999,7 @@ func (a *Manager) CreateTier(tier *Tier) error {
 	if tier.ID == "" {
 		tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
 	}
-	if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
+	if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessageLimit, int64(tier.MessageExpiryDuration.Seconds()), tier.EmailLimit, tier.ReservationLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
 		return err
 	}
 	return nil
@@ -1070,10 +1070,10 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
 		ID:                       id,
 		Code:                     code,
 		Name:                     name,
-		MessagesLimit:            messagesLimit.Int64,
-		MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
-		EmailsLimit:              emailsLimit.Int64,
-		ReservationsLimit:        reservationsLimit.Int64,
+		MessageLimit:             messagesLimit.Int64,
+		MessageExpiryDuration:    time.Duration(messagesExpiryDuration.Int64) * time.Second,
+		EmailLimit:               emailsLimit.Int64,
+		ReservationLimit:         reservationsLimit.Int64,
 		AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
 		AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
 		AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,

+ 8 - 8
user/manager_test.go

@@ -335,10 +335,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
 		Code:                     "pro",
 		Name:                     "ntfy Pro",
 		StripePriceID:            "price123",
-		MessagesLimit:            5_000,
-		MessagesExpiryDuration:   3 * 24 * time.Hour,
-		EmailsLimit:              50,
-		ReservationsLimit:        5,
+		MessageLimit:             5_000,
+		MessageExpiryDuration:    3 * 24 * time.Hour,
+		EmailLimit:               50,
+		ReservationLimit:         5,
 		AttachmentFileSizeLimit:  52428800,
 		AttachmentTotalSizeLimit: 524288000,
 		AttachmentExpiryDuration: 24 * time.Hour,
@@ -351,10 +351,10 @@ func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
 	require.Nil(t, err)
 	require.Equal(t, RoleUser, ben.Role)
 	require.Equal(t, "pro", ben.Tier.Code)
-	require.Equal(t, int64(5000), ben.Tier.MessagesLimit)
-	require.Equal(t, 3*24*time.Hour, ben.Tier.MessagesExpiryDuration)
-	require.Equal(t, int64(50), ben.Tier.EmailsLimit)
-	require.Equal(t, int64(5), ben.Tier.ReservationsLimit)
+	require.Equal(t, int64(5000), ben.Tier.MessageLimit)
+	require.Equal(t, 3*24*time.Hour, ben.Tier.MessageExpiryDuration)
+	require.Equal(t, int64(50), ben.Tier.EmailLimit)
+	require.Equal(t, int64(5), ben.Tier.ReservationLimit)
 	require.Equal(t, int64(52428800), ben.Tier.AttachmentFileSizeLimit)
 	require.Equal(t, int64(524288000), ben.Tier.AttachmentTotalSizeLimit)
 	require.Equal(t, 24*time.Hour, ben.Tier.AttachmentExpiryDuration)

+ 4 - 4
user/types.go

@@ -62,10 +62,10 @@ type Tier struct {
 	ID                       string        // Tier identifier (ti_...)
 	Code                     string        // Code of the tier
 	Name                     string        // Name of the tier
-	MessagesLimit            int64         // Daily message limit
-	MessagesExpiryDuration   time.Duration // Cache duration for messages
-	EmailsLimit              int64         // Daily email limit
-	ReservationsLimit        int64         // Number of topic reservations allowed by user
+	MessageLimit             int64         // Daily message limit
+	MessageExpiryDuration    time.Duration // Cache duration for messages
+	EmailLimit               int64         // Daily email limit
+	ReservationLimit         int64         // Number of topic reservations allowed by user
 	AttachmentFileSizeLimit  int64         // Max file size per file (bytes)
 	AttachmentTotalSizeLimit int64         // Total file size for all files of this user (bytes)
 	AttachmentExpiryDuration time.Duration // Duration after which attachments will be deleted

+ 6 - 0
util/limit.go

@@ -27,8 +27,14 @@ type FixedLimiter struct {
 
 // NewFixedLimiter creates a new Limiter
 func NewFixedLimiter(limit int64) *FixedLimiter {
+	return NewFixedLimiterWithValue(limit, 0)
+}
+
+// NewFixedLimiterWithValue creates a new Limiter and sets the initial value
+func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
 	return &FixedLimiter{
 		limit: limit,
+		value: value,
 	}
 }
 

+ 1 - 1
util/time.go

@@ -17,7 +17,7 @@ var (
 // NextOccurrenceUTC takes a time of day (e.g. 9:00am), and returns the next occurrence
 // of that time from the current time (in UTC).
 func NextOccurrenceUTC(timeOfDay, base time.Time) time.Time {
-	hour, minute, seconds := timeOfDay.Clock()
+	hour, minute, seconds := timeOfDay.UTC().Clock()
 	now := base.UTC()
 	next := time.Date(now.Year(), now.Month(), now.Day(), hour, minute, seconds, 0, time.UTC)
 	if next.Before(now) {

+ 11 - 0
util/util.go

@@ -337,6 +337,17 @@ func Retry[T any](f func() (*T, error), after ...time.Duration) (t *T, err error
 	return nil, err
 }
 
+// MinMax returns value if it is between min and max, or either
+// min or max if it is out of range
+func MinMax[T int | int64](value, min, max T) T {
+	if value < min {
+		return min
+	} else if value > max {
+		return max
+	}
+	return value
+}
+
 // String turns a string into a pointer of a string
 func String(v string) *string {
 	return &v