Parcourir la source

Rate limiting refactor, race fixes, more tests

binwiederhier il y a 3 ans
Parent
commit
62140ec001
8 fichiers modifiés avec 237 ajouts et 113 suppressions
  1. 12 22
      server/server.go
  2. 4 3
      server/server_account.go
  3. 63 5
      server/server_test.go
  4. 51 44
      server/visitor.go
  5. 3 3
      user/manager.go
  6. 3 3
      user/manager_test.go
  7. 73 15
      util/limit.go
  8. 28 18
      util/limit_test.go

+ 12 - 22
server/server.go

@@ -35,27 +35,19 @@ import (
 )
 
 /*
-TODO
---
 
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
-- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
-- MEDIUM Rate limiting: Test daily message quota read from database initially
 - MEDIUM: Races with v.user (see publishSyncEventAsync test)
+- MEDIUM: Test that anonymous user and user without tier are the same visitor
+- MEDIUM: Make sure account endpoints make sense for admins
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
 - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
 - MEDIUM: Reservation table delete button: dialog "keep or delete messages?"
+- MEDIUM: Tests for remaining payment endpoints
 - LOW: UI: Flickering upgrade banner when logging in
 - LOW: JS constants
 - LOW: Payments reconciliation process
 
-Limits & rate limiting:
-	users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
-
-Make sure account endpoints make sense for admins
-
-Tests:
-- Payment endpoints (make mocks)
 */
 
 // Server is the main server, providing the UI and API for ntfy
@@ -513,7 +505,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
 		return errHTTPNotFound
 	}
 	if r.Method == http.MethodGet {
-		if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
+		if !v.BandwidthAllowed(stat.Size()) {
 			return errHTTPTooManyRequestsLimitAttachmentBandwidth
 		}
 	}
@@ -543,7 +535,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if err != nil {
 		return nil, err
 	}
-	if err := v.MessageAllowed(); err != nil {
+	if !v.MessageAllowed() {
 		return nil, errHTTPTooManyRequestsLimitMessages
 	}
 	body, err := util.Peek(r.Body, s.config.MessageLimit)
@@ -558,9 +550,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if m.PollID != "" {
 		m = newPollRequestMessage(t.ID, m.PollID)
 	}
-	if v.user != nil {
-		m.User = v.user.ID
-	}
+	m.User = v.MaybeUserID()
 	m.Expires = time.Now().Add(v.Limits().MessageExpiryDuration).Unix()
 	if err := s.handlePublishBody(r, v, m, body, unifiedpush); err != nil {
 		return nil, err
@@ -582,7 +572,6 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 			go s.sendToFirebase(v, m)
 		}
 		if s.smtpSender != nil && email != "" {
-			v.IncrementEmails()
 			go s.sendEmail(v, m, email)
 		}
 		if s.config.UpstreamBaseURL != "" {
@@ -597,8 +586,9 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 			return nil, err
 		}
 	}
-	if s.userManager != nil && v.user != nil {
-		s.userManager.EnqueueStats(v.user.ID, v.Stats()) // FIXME this makes no sense for tier-less users
+	u := v.User()
+	if s.userManager != nil && u != nil && u.Tier != nil {
+		s.userManager.EnqueueStats(u.ID, v.Stats())
 	}
 	s.mu.Lock()
 	s.messages++
@@ -704,7 +694,7 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
 	}
 	email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
 	if email != "" {
-		if err := v.EmailAllowed(); err != nil {
+		if !v.EmailAllowed() {
 			return false, false, "", false, errHTTPTooManyRequestsLimitEmails
 		}
 	}
@@ -909,7 +899,7 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v
 func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *visitor, contentType string, encoder messageEncoder) error {
 	log.Debug("%s HTTP stream connection opened", logHTTPPrefix(v, r))
 	defer log.Debug("%s HTTP stream connection closed", logHTTPPrefix(v, r))
-	if err := v.SubscriptionAllowed(); err != nil {
+	if !v.SubscriptionAllowed() {
 		return errHTTPTooManyRequestsLimitSubscriptions
 	}
 	defer v.RemoveSubscription()
@@ -989,7 +979,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 	if strings.ToLower(r.Header.Get("Upgrade")) != "websocket" {
 		return errHTTPBadRequestWebSocketsUpgradeHeaderMissing
 	}
-	if err := v.SubscriptionAllowed(); err != nil {
+	if !v.SubscriptionAllowed() {
 		return errHTTPTooManyRequestsLimitSubscriptions
 	}
 	defer v.RemoveSubscription()

+ 4 - 3
server/server_account.go

@@ -23,7 +23,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 		} else if v.user != nil {
 			return errHTTPUnauthorized // Cannot create account from user context
 		}
-		if err := v.AccountCreationAllowed(); err != nil {
+		if !v.AccountCreationAllowed() {
 			return errHTTPTooManyRequestsLimitAccountCreation
 		}
 	}
@@ -428,11 +428,12 @@ func (s *Server) publishSyncEvent(v *visitor) error {
 
 func (s *Server) publishSyncEventAsync(v *visitor) {
 	go func() {
-		if v.user == nil || v.user.SyncTopic == "" {
+		u := v.User()
+		if u == nil || u.SyncTopic == "" {
 			return
 		}
 		if err := s.publishSyncEvent(v); err != nil {
-			log.Trace("Error publishing to user %s's sync topic %s: %s", v.user.Name, v.user.SyncTopic, err.Error())
+			log.Trace("Error publishing to user %s's sync topic %s: %s", u.Name, u.SyncTopic, err.Error())
 		}
 	}()
 }

+ 63 - 5
server/server_test.go

@@ -841,23 +841,35 @@ func TestServer_StatsResetter(t *testing.T) {
 	require.Equal(t, int64(0), account.Stats.Messages)
 }
 
-func TestServer_StatsResetter_MessageLimiter(t *testing.T) {
-	// This tests that the messageLimiter (the only fixed limiter) is reset by the stats resetter
+func TestServer_StatsResetter_MessageLimiter_EmailsLimiter(t *testing.T) {
+	// This tests that the messageLimiter (the only fixed limiter) and the emailsLimiter (token bucket)
+	// is reset by the stats resetter
 
 	c := newTestConfigWithAuthFile(t)
 	s := newTestServer(t, c)
+	s.smtpSender = &testMailer{}
 
 	// Publish some messages, and check stats
 	for i := 0; i < 3; i++ {
 		response := request(t, s, "PUT", "/mytopic", "test", nil)
 		require.Equal(t, 200, response.Code)
 	}
+	response := request(t, s, "PUT", "/mytopic", "test", map[string]string{
+		"Email": "test@email.com",
+	})
+	require.Equal(t, 200, response.Code)
+
 	rr := request(t, s, "GET", "/v1/account", "", nil)
 	require.Equal(t, 200, rr.Code)
 	account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
 	require.Nil(t, err)
-	require.Equal(t, int64(3), account.Stats.Messages)
-	require.Equal(t, int64(3), s.visitor(netip.MustParseAddr("9.9.9.9"), nil).messagesLimiter.Value())
+	require.Equal(t, int64(4), account.Stats.Messages)
+	require.Equal(t, int64(1), account.Stats.Emails)
+	v := s.visitor(netip.MustParseAddr("9.9.9.9"), nil)
+	require.Equal(t, int64(4), v.Stats().Messages)
+	require.Equal(t, int64(4), v.messagesLimiter.Value())
+	require.Equal(t, int64(1), v.Stats().Emails)
+	require.Equal(t, int64(1), v.emailsLimiter.Value())
 
 	// Reset stats and check again
 	s.resetStats()
@@ -866,7 +878,53 @@ func TestServer_StatsResetter_MessageLimiter(t *testing.T) {
 	account, err = util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
 	require.Nil(t, err)
 	require.Equal(t, int64(0), account.Stats.Messages)
-	require.Equal(t, int64(0), s.visitor(netip.MustParseAddr("9.9.9.9"), nil).messagesLimiter.Value())
+	require.Equal(t, int64(0), account.Stats.Emails)
+	v = s.visitor(netip.MustParseAddr("9.9.9.9"), nil)
+	require.Equal(t, int64(0), v.Stats().Messages)
+	require.Equal(t, int64(0), v.messagesLimiter.Value())
+	require.Equal(t, int64(0), v.Stats().Emails)
+	require.Equal(t, int64(0), v.emailsLimiter.Value())
+}
+
+func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
+	// This tests that the daily message quota is prefilled originally from the database,
+	// if the visitor is unknown
+
+	c := newTestConfigWithAuthFile(t)
+	s := newTestServer(t, c)
+	var err error
+	s.userManager, err = user.NewManagerWithStatsInterval(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, 100*time.Millisecond)
+	require.Nil(t, err)
+
+	// Create user, and update it with some message and email stats
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code: "test",
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("phil", "test"))
+
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+	s.userManager.EnqueueStats(u.ID, &user.Stats{
+		Messages: 123456,
+		Emails:   999,
+	})
+	time.Sleep(400 * time.Millisecond)
+
+	// Get account and verify stats are read from the DB, and that the visitor also has these stats
+	rr := request(t, s, "GET", "/v1/account", "", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 200, rr.Code)
+	account, err := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
+	require.Nil(t, err)
+	require.Equal(t, int64(123456), account.Stats.Messages)
+	require.Equal(t, int64(999), account.Stats.Emails)
+	v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
+	require.Equal(t, int64(123456), v.Stats().Messages)
+	require.Equal(t, int64(123456), v.messagesLimiter.Value())
+	require.Equal(t, int64(999), v.Stats().Emails)
+	require.Equal(t, int64(999), v.emailsLimiter.Value())
 }
 
 type testMailer struct {

+ 51 - 44
server/visitor.go

@@ -58,12 +58,11 @@ type visitor struct {
 	userManager         *user.Manager      // May be nil
 	ip                  netip.Addr         // Visitor IP address
 	user                *user.User         // Only set if authenticated user, otherwise nil
-	emails              int64              // Number of emails sent, reset every day
 	requestLimiter      *rate.Limiter      // Rate limiter for (almost) all requests (including messages)
 	messagesLimiter     *util.FixedLimiter // Rate limiter for messages
-	emailsLimiter       *rate.Limiter      // Rate limiter for emails
-	subscriptionLimiter util.Limiter       // Fixed limiter for active subscriptions (ongoing connections)
-	bandwidthLimiter    util.Limiter       // Limiter for attachment bandwidth downloads
+	emailsLimiter       *util.RateLimiter  // Rate limiter for emails
+	subscriptionLimiter *util.FixedLimiter // Fixed limiter for active subscriptions (ongoing connections)
+	bandwidthLimiter    *util.RateLimiter  // Limiter for attachment bandwidth downloads
 	accountLimiter      *rate.Limiter      // Rate limiter for account creation, may be nil
 	firebase            time.Time          // Next allowed Firebase message
 	seen                time.Time          // Last seen time of this visitor (needed for removal of stale visitors)
@@ -123,7 +122,6 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
 		userManager:         userManager, // May be nil
 		ip:                  ip,
 		user:                user,
-		emails:              emails,
 		firebase:            time.Unix(0, 0),
 		seen:                time.Now(),
 		subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
@@ -133,7 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
 		bandwidthLimiter:    nil, // Set in resetLimiters
 		accountLimiter:      nil, // Set in resetLimiters, may be nil
 	}
-	v.resetLimiters(messages)
+	v.resetLimitersNoLock(messages, emails)
 	return v
 }
 
@@ -153,6 +151,8 @@ func (v *visitor) stringNoLock() string {
 }
 
 func (v *visitor) RequestAllowed() error {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
 	if !v.requestLimiter.Allow() {
 		return errVisitorLimitReached
 	}
@@ -174,40 +174,43 @@ func (v *visitor) FirebaseTemporarilyDeny() {
 	v.firebase = time.Now().Add(v.config.FirebaseQuotaExceededPenaltyDuration)
 }
 
-func (v *visitor) MessageAllowed() error {
-	if v.messagesLimiter.Allow(1) != nil {
-		return errVisitorLimitReached
-	}
-	return nil
+func (v *visitor) MessageAllowed() bool {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
+	return v.messagesLimiter.Allow()
 }
 
-func (v *visitor) EmailAllowed() error {
-	if !v.emailsLimiter.Allow() {
-		return errVisitorLimitReached
-	}
-	return nil
+func (v *visitor) EmailAllowed() bool {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
+	return v.emailsLimiter.Allow()
 }
 
-func (v *visitor) SubscriptionAllowed() error {
-	v.mu.Lock()
+func (v *visitor) SubscriptionAllowed() bool {
+	v.mu.Lock() // limiters could be replaced!
 	defer v.mu.Unlock()
-	if err := v.subscriptionLimiter.Allow(1); err != nil {
-		return errVisitorLimitReached
-	}
-	return nil
+	return v.subscriptionLimiter.Allow()
 }
 
-func (v *visitor) AccountCreationAllowed() error {
+func (v *visitor) AccountCreationAllowed() bool {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
 	if v.accountLimiter == nil || (v.accountLimiter != nil && !v.accountLimiter.Allow()) {
-		return errVisitorLimitReached
+		return false
 	}
-	return nil
+	return true
+}
+
+func (v *visitor) BandwidthAllowed(bytes int64) bool {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
+	return v.bandwidthLimiter.AllowN(bytes)
 }
 
 func (v *visitor) RemoveSubscription() {
 	v.mu.Lock()
 	defer v.mu.Unlock()
-	v.subscriptionLimiter.Allow(-1)
+	v.subscriptionLimiter.AllowN(-1)
 }
 
 func (v *visitor) Keepalive() {
@@ -217,6 +220,8 @@ func (v *visitor) Keepalive() {
 }
 
 func (v *visitor) BandwidthLimiter() util.Limiter {
+	v.mu.Lock() // limiters could be replaced!
+	defer v.mu.Unlock()
 	return v.bandwidthLimiter
 }
 
@@ -226,28 +231,29 @@ func (v *visitor) Stale() bool {
 	return time.Since(v.seen) > visitorExpungeAfter
 }
 
-func (v *visitor) IncrementEmails() {
-	v.mu.Lock()
-	defer v.mu.Unlock()
-	v.emails++
-}
-
 func (v *visitor) Stats() *user.Stats {
-	v.mu.Lock()
+	v.mu.Lock() // limiters could be replaced!
 	defer v.mu.Unlock()
 	return &user.Stats{
 		Messages: v.messagesLimiter.Value(),
-		Emails:   v.emails,
+		Emails:   v.emailsLimiter.Value(),
 	}
 }
 
 func (v *visitor) ResetStats() {
-	v.mu.Lock()
+	v.mu.Lock() // limiters could be replaced!
 	defer v.mu.Unlock()
-	v.emails = 0
+	v.emailsLimiter.Reset()
 	v.messagesLimiter.Reset()
 }
 
+// User returns the visitor user, or nil if there is none
+func (v *visitor) User() *user.User {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	return v.user // May be nil
+}
+
 // SetUser sets the visitors user to the given value
 func (v *visitor) SetUser(u *user.User) {
 	v.mu.Lock()
@@ -255,7 +261,7 @@ func (v *visitor) SetUser(u *user.User) {
 	shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
 	v.user = u
 	if shouldResetLimiters {
-		v.resetLimiters(0)
+		v.resetLimitersNoLock(0, 0)
 	}
 }
 
@@ -270,12 +276,12 @@ func (v *visitor) MaybeUserID() string {
 	return ""
 }
 
-func (v *visitor) resetLimiters(messages int64) {
+func (v *visitor) resetLimitersNoLock(messages, emails int64) {
 	log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
 	limits := v.limitsNoLock()
 	v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
 	v.messagesLimiter = util.NewFixedLimiterWithValue(limits.MessageLimit, messages)
-	v.emailsLimiter = rate.NewLimiter(limits.EmailLimitReplenish, limits.EmailLimitBurst)
+	v.emailsLimiter = util.NewRateLimiterWithValue(limits.EmailLimitReplenish, limits.EmailLimitBurst, emails)
 	v.bandwidthLimiter = util.NewBytesLimiter(int(limits.AttachmentBandwidthLimit), oneDay)
 	if v.user == nil {
 		v.accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
@@ -340,12 +346,13 @@ func configBasedVisitorLimits(conf *Config) *visitorLimits {
 func (v *visitor) Info() (*visitorInfo, error) {
 	v.mu.Lock()
 	messages := v.messagesLimiter.Value()
-	emails := v.emails
+	emails := v.emailsLimiter.Value()
 	v.mu.Unlock()
 	var attachmentsBytesUsed int64
 	var err error
-	if v.user != nil {
-		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(v.user.ID)
+	u := v.User()
+	if u != nil {
+		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID)
 	} else {
 		attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
 	}
@@ -353,8 +360,8 @@ func (v *visitor) Info() (*visitorInfo, error) {
 		return nil, err
 	}
 	var reservations int64
-	if v.user != nil && v.userManager != nil {
-		reservations, err = v.userManager.ReservationsCount(v.user.Name)
+	if v.userManager != nil && u != nil {
+		reservations, err = v.userManager.ReservationsCount(u.Name)
 		if err != nil {
 			return nil, err
 		}

+ 3 - 3
user/manager.go

@@ -301,11 +301,11 @@ var _ Auther = (*Manager)(nil)
 
 // NewManager creates a new Manager instance
 func NewManager(filename, startupQueries string, defaultAccess Permission) (*Manager, error) {
-	return newManager(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
+	return NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, userStatsQueueWriterInterval)
 }
 
-// NewManager creates a new Manager instance
-func newManager(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
+// NewManagerWithStatsInterval creates a new Manager instance
+func NewManagerWithStatsInterval(filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) (*Manager, error) {
 	db, err := sql.Open("sqlite3", filename)
 	if err != nil {
 		return nil, err

+ 3 - 3
user/manager_test.go

@@ -545,7 +545,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
 }
 
 func TestManager_EnqueueStats(t *testing.T) {
-	a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
+	a, err := NewManagerWithStatsInterval(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
 	require.Nil(t, err)
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
 
@@ -575,7 +575,7 @@ func TestManager_EnqueueStats(t *testing.T) {
 }
 
 func TestManager_ChangeSettings(t *testing.T) {
-	a, err := newManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
+	a, err := NewManagerWithStatsInterval(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, 1500*time.Millisecond)
 	require.Nil(t, err)
 	require.Nil(t, a.AddUser("ben", "ben", RoleUser))
 
@@ -718,7 +718,7 @@ func newTestManager(t *testing.T, defaultAccess Permission) *Manager {
 }
 
 func newTestManagerFromFile(t *testing.T, filename, startupQueries string, defaultAccess Permission, statsWriterInterval time.Duration) *Manager {
-	a, err := newManager(filename, startupQueries, defaultAccess, statsWriterInterval)
+	a, err := NewManagerWithStatsInterval(filename, startupQueries, defaultAccess, statsWriterInterval)
 	require.Nil(t, err)
 	return a
 }

+ 73 - 15
util/limit.go

@@ -13,8 +13,17 @@ var ErrLimitReached = errors.New("limit reached")
 
 // Limiter is an interface that implements a rate limiting mechanism, e.g. based on time or a fixed value
 type Limiter interface {
-	// Allow adds n to the limiters internal value, or returns ErrLimitReached if the limit has been reached
-	Allow(n int64) error
+	// Allow adds one to the limiters value, or returns false if the limit has been reached
+	Allow() bool
+
+	// AllowN adds n to the limiters value, or returns false if the limit has been reached
+	AllowN(n int64) bool
+
+	// Value returns the current internal limiter value
+	Value() int64
+
+	// Reset resets the state of the limiter
+	Reset()
 }
 
 // FixedLimiter is a helper that allows adding values up to a well-defined limit. Once the limit is reached
@@ -25,6 +34,8 @@ type FixedLimiter struct {
 	mu    sync.Mutex
 }
 
+var _ Limiter = (*FixedLimiter)(nil)
+
 // NewFixedLimiter creates a new Limiter
 func NewFixedLimiter(limit int64) *FixedLimiter {
 	return NewFixedLimiterWithValue(limit, 0)
@@ -38,16 +49,22 @@ func NewFixedLimiterWithValue(limit, value int64) *FixedLimiter {
 	}
 }
 
-// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
-// exceeded after adding n, ErrLimitReached is returned.
-func (l *FixedLimiter) Allow(n int64) error {
+// Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was
+// exceeded, false is returned.
+func (l *FixedLimiter) Allow() bool {
+	return l.AllowN(1)
+}
+
+// AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
+// exceeded after adding n, false is returned.
+func (l *FixedLimiter) AllowN(n int64) bool {
 	l.mu.Lock()
 	defer l.mu.Unlock()
 	if l.value+n > l.limit {
-		return ErrLimitReached
+		return false
 	}
 	l.value += n
-	return nil
+	return true
 }
 
 // Value returns the current limiter value
@@ -66,12 +83,29 @@ func (l *FixedLimiter) Reset() {
 
 // RateLimiter is a Limiter that wraps a rate.Limiter, allowing a floating time-based limit.
 type RateLimiter struct {
+	r       rate.Limit
+	b       int
+	value   int64
 	limiter *rate.Limiter
+	mu      sync.Mutex
 }
 
+var _ Limiter = (*RateLimiter)(nil)
+
 // NewRateLimiter creates a new RateLimiter
 func NewRateLimiter(r rate.Limit, b int) *RateLimiter {
+	return NewRateLimiterWithValue(r, b, 0)
+}
+
+// NewRateLimiterWithValue creates a new RateLimiter with the given starting value.
+//
+// Note that the starting value only has informational value. It does not impact the underlying
+// value of the rate.Limiter.
+func NewRateLimiterWithValue(r rate.Limit, b int, value int64) *RateLimiter {
 	return &RateLimiter{
+		r:       r,
+		b:       b,
+		value:   value,
 		limiter: rate.NewLimiter(r, b),
 	}
 }
@@ -82,16 +116,40 @@ func NewBytesLimiter(bytes int, interval time.Duration) *RateLimiter {
 	return NewRateLimiter(rate.Limit(bytes)*rate.Every(interval), bytes)
 }
 
-// Allow adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
-// exceeded after adding n, ErrLimitReached is returned.
-func (l *RateLimiter) Allow(n int64) error {
+// Allow adds one to the limiters internal value, but only if the limit has not been reached. If the limit was
+// exceeded, false is returned.
+func (l *RateLimiter) Allow() bool {
+	return l.AllowN(1)
+}
+
+// AllowN adds n to the limiters internal value, but only if the limit has not been reached. If the limit was
+// exceeded after adding n, false is returned.
+func (l *RateLimiter) AllowN(n int64) bool {
 	if n <= 0 {
-		return nil // No-op. Can't take back bytes you're written!
+		return false // No-op. Can't take back bytes you're written!
 	}
+	l.mu.Lock()
+	defer l.mu.Unlock()
 	if !l.limiter.AllowN(time.Now(), int(n)) {
-		return ErrLimitReached
+		return false
 	}
-	return nil
+	l.value += n
+	return true
+}
+
+// Value returns the current limiter value
+func (l *RateLimiter) Value() int64 {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	return l.value
+}
+
+// Reset sets the limiter's value back to zero, and resets the underlying rate.Limiter
+func (l *RateLimiter) Reset() {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	l.limiter = rate.NewLimiter(l.r, l.b)
+	l.value = 0
 }
 
 // LimitWriter implements an io.Writer that will pass through all Write calls to the underlying
@@ -117,9 +175,9 @@ func (w *LimitWriter) Write(p []byte) (n int, err error) {
 	w.mu.Lock()
 	defer w.mu.Unlock()
 	for i := 0; i < len(w.limiters); i++ {
-		if err := w.limiters[i].Allow(int64(len(p))); err != nil {
+		if !w.limiters[i].AllowN(int64(len(p))) {
 			for j := i - 1; j >= 0; j-- {
-				w.limiters[j].Allow(-int64(len(p))) // Revert limiters limits if allowed
+				w.limiters[j].AllowN(-int64(len(p))) // Revert limiters limits if not allowed
 			}
 			return 0, ErrLimitReached
 		}

+ 28 - 18
util/limit_test.go

@@ -7,26 +7,31 @@ import (
 	"time"
 )
 
-func TestFixedLimiter_Add(t *testing.T) {
+func TestFixedLimiter_AllowValueReset(t *testing.T) {
 	l := NewFixedLimiter(10)
-	if err := l.Allow(5); err != nil {
-		t.Fatal(err)
-	}
-	if err := l.Allow(5); err != nil {
-		t.Fatal(err)
-	}
-	if err := l.Allow(5); err != ErrLimitReached {
-		t.Fatalf("expected ErrLimitReached, got %#v", err)
-	}
+	require.True(t, l.AllowN(5))
+	require.Equal(t, int64(5), l.Value())
+
+	require.True(t, l.AllowN(5))
+	require.Equal(t, int64(10), l.Value())
+
+	require.False(t, l.Allow())
+	require.Equal(t, int64(10), l.Value())
+
+	l.Reset()
+	require.Equal(t, int64(0), l.Value())
+	require.True(t, l.Allow())
+	require.True(t, l.AllowN(9))
+	require.False(t, l.Allow())
 }
 
 func TestFixedLimiter_AddSub(t *testing.T) {
 	l := NewFixedLimiter(10)
-	l.Allow(5)
+	l.AllowN(5)
 	if l.value != 5 {
 		t.Fatalf("expected value to be %d, got %d", 5, l.value)
 	}
-	l.Allow(-2)
+	l.AllowN(-2)
 	if l.value != 3 {
 		t.Fatalf("expected value to be %d, got %d", 7, l.value)
 	}
@@ -34,17 +39,22 @@ func TestFixedLimiter_AddSub(t *testing.T) {
 
 func TestBytesLimiter_Add_Simple(t *testing.T) {
 	l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h
-	require.Nil(t, l.Allow(100*1024*1024))
-	require.Nil(t, l.Allow(100*1024*1024))
-	require.Equal(t, ErrLimitReached, l.Allow(300*1024*1024))
+	require.True(t, l.AllowN(100*1024*1024))
+	require.Equal(t, int64(100*1024*1024), l.Value())
+
+	require.True(t, l.AllowN(100*1024*1024))
+	require.Equal(t, int64(200*1024*1024), l.Value())
+
+	require.False(t, l.AllowN(300*1024*1024))
+	require.Equal(t, int64(200*1024*1024), l.Value())
 }
 
 func TestBytesLimiter_Add_Wait(t *testing.T) {
 	l := NewBytesLimiter(250*1024*1024, 24*time.Hour) // 250 MB per 24h (~ 303 bytes per 100ms)
-	require.Nil(t, l.Allow(250*1024*1024))
-	require.Equal(t, ErrLimitReached, l.Allow(400))
+	require.True(t, l.AllowN(250*1024*1024))
+	require.False(t, l.AllowN(400))
 	time.Sleep(200 * time.Millisecond)
-	require.Nil(t, l.Allow(400))
+	require.True(t, l.AllowN(400))
 }
 
 func TestLimitWriter_WriteNoLimiter(t *testing.T) {