Browse Source

Payment checkout test, rate limit resetting on tier change; failing

binwiederhier 3 years ago
parent
commit
593e0748a8
8 changed files with 257 additions and 42 deletions
  1. 6 6
      server/config.go
  2. 3 5
      server/server.go
  3. 1 1
      server/server_account.go
  4. 72 0
      server/server_account_test.go
  5. 105 0
      server/server_payments_test.go
  6. 50 22
      server/visitor.go
  7. 11 8
      user/manager.go
  8. 9 0
      user/types.go

+ 6 - 6
server/config.go

@@ -46,8 +46,8 @@ const (
 	DefaultVisitorRequestLimitReplenish         = 5 * time.Second
 	DefaultVisitorRequestLimitReplenish         = 5 * time.Second
 	DefaultVisitorEmailLimitBurst               = 16
 	DefaultVisitorEmailLimitBurst               = 16
 	DefaultVisitorEmailLimitReplenish           = time.Hour
 	DefaultVisitorEmailLimitReplenish           = time.Hour
-	DefaultVisitorAccountCreateLimitBurst       = 3
-	DefaultVisitorAccountCreateLimitReplenish   = 24 * time.Hour
+	DefaultVisitorAccountCreationLimitBurst     = 3
+	DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
 	DefaultVisitorAttachmentTotalSizeLimit      = 100 * 1024 * 1024 // 100 MB
 	DefaultVisitorAttachmentTotalSizeLimit      = 100 * 1024 * 1024 // 100 MB
 	DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
 	DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
 )
 )
@@ -107,8 +107,8 @@ type Config struct {
 	VisitorRequestExemptIPAddrs          []netip.Prefix
 	VisitorRequestExemptIPAddrs          []netip.Prefix
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitReplenish           time.Duration
 	VisitorEmailLimitReplenish           time.Duration
-	VisitorAccountCreateLimitBurst       int
-	VisitorAccountCreateLimitReplenish   time.Duration
+	VisitorAccountCreationLimitBurst     int
+	VisitorAccountCreationLimitReplenish time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	BehindProxy                          bool
 	BehindProxy                          bool
 	StripeSecretKey                      string
 	StripeSecretKey                      string
@@ -173,8 +173,8 @@ func NewConfig() *Config {
 		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0),
 		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0),
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
-		VisitorAccountCreateLimitBurst:       DefaultVisitorAccountCreateLimitBurst,
-		VisitorAccountCreateLimitReplenish:   DefaultVisitorAccountCreateLimitReplenish,
+		VisitorAccountCreationLimitBurst:     DefaultVisitorAccountCreationLimitBurst,
+		VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
 		VisitorStatsResetTime:                DefaultVisitorStatsResetTime,
 		VisitorStatsResetTime:                DefaultVisitorStatsResetTime,
 		BehindProxy:                          false,
 		BehindProxy:                          false,
 		StripeSecretKey:                      "",
 		StripeSecretKey:                      "",

+ 3 - 5
server/server.go

@@ -40,6 +40,8 @@ TODO
 
 
 - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
 - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
+- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
+- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
 - MEDIUM: Races with v.user (see publishSyncEventAsync test)
 - MEDIUM: Races with v.user (see publishSyncEventAsync test)
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
 - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
 - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@@ -50,8 +52,6 @@ TODO
 
 
 Limits & rate limiting:
 Limits & rate limiting:
 	users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
 	users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
-	when ResetStats() is run, reset messagesLimiter (and others)?
-	Delete visitor when tier is changed to refresh rate limiters
 
 
 Make sure account endpoints make sense for admins
 Make sure account endpoints make sense for admins
 
 
@@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
 	} else {
 	} else {
 		v = s.visitorFromIP(ip)
 		v = s.visitorFromIP(ip)
 	}
 	}
-	v.mu.Lock()
-	v.user = u
-	v.mu.Unlock()
+	v.SetUser(u)  // Update visitor user with latest from database!
 	return v, err // Always return visitor, even when error occurs!
 	return v, err // Always return visitor, even when error occurs!
 }
 }
 
 

+ 1 - 1
server/server_account.go

@@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 	if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
 	if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
 		return errHTTPConflictUserExists
 		return errHTTPConflictUserExists
 	}
 	}
-	if v.accountLimiter != nil && !v.accountLimiter.Allow() {
+	if err := v.AccountCreationAllowed(); err != nil {
 		return errHTTPTooManyRequestsLimitAccountCreation
 		return errHTTPTooManyRequestsLimitAccountCreation
 	}
 	}
 	if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User
 	if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User

+ 72 - 0
server/server_account_test.go

@@ -6,6 +6,7 @@ import (
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"io"
 	"io"
+	"strings"
 	"testing"
 	"testing"
 	"time"
 	"time"
 )
 )
@@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) {
 	require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
 	require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
 }
 }
 
 
+func TestAccount_Signup_Rate_Limit(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	conf.EnableSignup = true
+	s := newTestServer(t, conf)
+
+	for i := 0; i < 3; i++ {
+		rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
+		require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
+	}
+	rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil)
+	require.Equal(t, 429, rr.Code)
+	require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code)
+}
+
 func TestAccount_Get_Anonymous(t *testing.T) {
 func TestAccount_Get_Anonymous(t *testing.T) {
 	conf := newTestConfigWithAuthFile(t)
 	conf := newTestConfigWithAuthFile(t)
 	conf.VisitorRequestLimitReplenish = 86 * time.Second
 	conf.VisitorRequestLimitReplenish = 86 * time.Second
@@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
 	s.topics["mytopic"].CancelSubscribers("<invalid>")
 	s.topics["mytopic"].CancelSubscribers("<invalid>")
 	<-userCh
 	<-userCh
 }
 }
+
+func TestAccount_Tier_Create(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	s := newTestServer(t, conf)
+
+	// Create tier and user
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:                     "pro",
+		Name:                     "Pro",
+		MessagesLimit:            123,
+		MessagesExpiryDuration:   86400 * time.Second,
+		EmailsLimit:              32,
+		ReservationsLimit:        2,
+		AttachmentFileSizeLimit:  1231231,
+		AttachmentTotalSizeLimit: 123123,
+		AttachmentExpiryDuration: 10800 * time.Second,
+		AttachmentBandwidthLimit: 21474836480,
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
+
+	ti, err := s.userManager.Tier("pro")
+	require.Nil(t, err)
+
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+
+	// These are populated by different SQL queries
+	require.Equal(t, ti, u.Tier)
+
+	// Fields
+	require.True(t, strings.HasPrefix(ti.ID, "ti_"))
+	require.Equal(t, "pro", ti.Code)
+	require.Equal(t, "Pro", ti.Name)
+	require.Equal(t, int64(123), ti.MessagesLimit)
+	require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
+	require.Equal(t, int64(32), ti.EmailsLimit)
+	require.Equal(t, int64(2), ti.ReservationsLimit)
+	require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
+	require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
+	require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
+	require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
+}
+
+func TestAccount_Tier_Create_With_ID(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	s := newTestServer(t, conf)
+
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:   "ti_123",
+		Code: "pro",
+	}))
+
+	ti, err := s.userManager.Tier("pro")
+	require.Nil(t, err)
+	require.Equal(t, "ti_123", ti.ID)
+}

+ 105 - 0
server/server_payments_test.go

@@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
 
 
 	// Create tier and user
 	// Create tier and user
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
 		Code:          "pro",
 		Code:          "pro",
 		StripePriceID: "price_123",
 		StripePriceID: "price_123",
 	}))
 	}))
@@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
 
 
 	// Create tier and user
 	// Create tier and user
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
 		Code:          "pro",
 		Code:          "pro",
 		StripePriceID: "price_123",
 		StripePriceID: "price_123",
 	}))
 	}))
@@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
 
 
 	// Create tier and user
 	// Create tier and user
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
 		Code:          "pro",
 		Code:          "pro",
 		StripePriceID: "price_123",
 		StripePriceID: "price_123",
 	}))
 	}))
@@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
 	require.Equal(t, 401, rr.Code)
 	require.Equal(t, 401, rr.Code)
 }
 }
 
 
+func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
+	// This tests a successful checkout flow (not a paying customer -> paying customer),
+	// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
+
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	c.VisitorRequestLimitBurst = 10
+	c.VisitorRequestLimitReplenish = time.Hour
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Create a user with a Stripe subscription and 3 reservations
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:                     "ti_123",
+		Code:                   "starter",
+		StripePriceID:          "price_1234",
+		ReservationsLimit:      1,
+		MessagesLimit:          100,
+		MessagesExpiryDuration: time.Hour,
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+
+	// Define how the mock should react
+	stripeMock.
+		On("GetSession", "SOMETOKEN").
+		Return(&stripe.CheckoutSession{
+			ClientReferenceID: u.ID, // ntfy user ID
+			Customer: &stripe.Customer{
+				ID: "acct_5555",
+			},
+			Subscription: &stripe.Subscription{
+				ID: "sub_1234",
+			},
+		}, nil)
+	stripeMock.
+		On("GetSubscription", "sub_1234").
+		Return(&stripe.Subscription{
+			ID:               "sub_1234",
+			Status:           stripe.SubscriptionStatusActive,
+			CurrentPeriodEnd: 123456789,
+			CancelAt:         0,
+			Items: &stripe.SubscriptionItemList{
+				Data: []*stripe.SubscriptionItem{
+					{
+						Price: &stripe.Price{ID: "price_1234"},
+					},
+				},
+			},
+		}, nil)
+	stripeMock.
+		On("UpdateCustomer", mock.Anything).
+		Return(&stripe.Customer{}, nil)
+
+	// Send messages until rate limit of free tier is hit
+	for i := 0; i < 10; i++ {
+		rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+			"Authorization": util.BasicAuth("phil", "phil"),
+		})
+		require.Equal(t, 200, rr.Code)
+	}
+	rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 429, rr.Code)
+
+	// Simulate Stripe success return URL call (no user context)
+	rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
+	require.Equal(t, 303, rr.Code)
+
+	// Verify that database columns were updated
+	u, err = s.userManager.User("phil")
+	require.Nil(t, err)
+	require.Equal(t, "starter", u.Tier.Code) // Not "pro"
+	require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
+	require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
+	require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
+	require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
+	require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
+
+	// FIXME FIXME This test is broken, because the rate limit logic is unclear!
+
+	// Now for the fun part: Verify that new rate limits are immediately applied
+	for i := 0; i < 100; i++ {
+		rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+			"Authorization": util.BasicAuth("phil", "phil"),
+		})
+		require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
+	}
+	rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 429, rr.Code)
+}
+
 func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
 func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
 	// This tests incoming webhooks from Stripe to update a subscription:
 	// This tests incoming webhooks from Stripe to update a subscription:
 	// - All Stripe columns are updated in the user table
 	// - All Stripe columns are updated in the user table
@@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 
 
 	// Create a user with a Stripe subscription and 3 reservations
 	// Create a user with a Stripe subscription and 3 reservations
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:                       "ti_1",
 		Code:                     "starter",
 		Code:                     "starter",
 		StripePriceID:            "price_1234", // !
 		StripePriceID:            "price_1234", // !
 		ReservationsLimit:        1,            // !
 		ReservationsLimit:        1,            // !
@@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 		AttachmentBandwidthLimit: 1000000,
 		AttachmentBandwidthLimit: 1000000,
 	}))
 	}))
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:                       "ti_2",
 		Code:                     "pro",
 		Code:                     "pro",
 		StripePriceID:            "price_1111", // !
 		StripePriceID:            "price_1111", // !
 		ReservationsLimit:        3,            // !
 		ReservationsLimit:        3,            // !

+ 50 - 22
server/visitor.go

@@ -3,6 +3,7 @@ package server
 import (
 import (
 	"errors"
 	"errors"
 	"fmt"
 	"fmt"
+	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/user"
 	"net/netip"
 	"net/netip"
 	"sync"
 	"sync"
@@ -41,7 +42,7 @@ type visitor struct {
 	emailsLimiter       *rate.Limiter // Rate limiter for emails
 	emailsLimiter       *rate.Limiter // Rate limiter for emails
 	subscriptionLimiter util.Limiter  // Fixed limiter for active subscriptions (ongoing connections)
 	subscriptionLimiter util.Limiter  // Fixed limiter for active subscriptions (ongoing connections)
 	bandwidthLimiter    util.Limiter  // Limiter for attachment bandwidth downloads
 	bandwidthLimiter    util.Limiter  // Limiter for attachment bandwidth downloads
-	accountLimiter      *rate.Limiter // Rate limiter for account creation
+	accountLimiter      *rate.Limiter // Rate limiter for account creation, may be nil
 	firebase            time.Time     // Next allowed Firebase message
 	firebase            time.Time     // Next allowed Firebase message
 	seen                time.Time     // Last seen time of this visitor (needed for removal of stale visitors)
 	seen                time.Time     // Last seen time of this visitor (needed for removal of stale visitors)
 	mu                  sync.Mutex
 	mu                  sync.Mutex
@@ -85,26 +86,12 @@ const (
 )
 )
 
 
 func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
 func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
-	var messagesLimiter, attachmentBandwidthLimiter util.Limiter
-	var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
 	var messages, emails int64
 	var messages, emails int64
 	if user != nil {
 	if user != nil {
 		messages = user.Stats.Messages
 		messages = user.Stats.Messages
 		emails = user.Stats.Emails
 		emails = user.Stats.Emails
-	} else {
-		accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst)
 	}
 	}
-	if user != nil && user.Tier != nil {
-		requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
-		messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
-		emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
-		attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
-	} else {
-		requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
-		emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
-		attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
-	}
-	return &visitor{
+	v := &visitor{
 		config:              conf,
 		config:              conf,
 		messageCache:        messageCache,
 		messageCache:        messageCache,
 		userManager:         userManager, // May be nil
 		userManager:         userManager, // May be nil
@@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
 		user:                user,
 		user:                user,
 		messages:            messages,
 		messages:            messages,
 		emails:              emails,
 		emails:              emails,
-		requestLimiter:      requestLimiter,
-		messagesLimiter:     messagesLimiter, // May be nil
-		emailsLimiter:       emailsLimiter,
-		subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
-		bandwidthLimiter:    attachmentBandwidthLimiter,
-		accountLimiter:      accountLimiter, // May be nil
 		firebase:            time.Unix(0, 0),
 		firebase:            time.Unix(0, 0),
 		seen:                time.Now(),
 		seen:                time.Now(),
+		subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
+		requestLimiter:      nil, // Set in resetLimiters
+		messagesLimiter:     nil, // Set in resetLimiters, may be nil
+		emailsLimiter:       nil, // Set in resetLimiters
+		bandwidthLimiter:    nil, // Set in resetLimiters
+		accountLimiter:      nil, // Set in resetLimiters, may be nil
 	}
 	}
+	v.resetLimiters()
+	return v
 }
 }
 
 
 func (v *visitor) String() string {
 func (v *visitor) String() string {
 	v.mu.Lock()
 	v.mu.Lock()
 	defer v.mu.Unlock()
 	defer v.mu.Unlock()
+	return v.stringNoLock()
+}
+
+func (v *visitor) stringNoLock() string {
 	if v.user != nil && v.user.Billing.StripeCustomerID != "" {
 	if v.user != nil && v.user.Billing.StripeCustomerID != "" {
 		return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
 		return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
 	} else if v.user != nil {
 	} else if v.user != nil {
@@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error {
 	return nil
 	return nil
 }
 }
 
 
+func (v *visitor) AccountCreationAllowed() error {
+	if v.accountLimiter != nil && !v.accountLimiter.Allow() {
+		return errVisitorLimitReached
+	}
+	return nil
+}
+
 func (v *visitor) RemoveSubscription() {
 func (v *visitor) RemoveSubscription() {
 	v.mu.Lock()
 	v.mu.Lock()
 	defer v.mu.Unlock()
 	defer v.mu.Unlock()
@@ -235,7 +235,35 @@ func (v *visitor) ResetStats() {
 func (v *visitor) SetUser(u *user.User) {
 func (v *visitor) SetUser(u *user.User) {
 	v.mu.Lock()
 	v.mu.Lock()
 	defer v.mu.Unlock()
 	defer v.mu.Unlock()
+	shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
 	v.user = u
 	v.user = u
+	if shouldResetLimiters {
+		v.resetLimiters()
+	}
+}
+
+func (v *visitor) resetLimiters() {
+	log.Info("%s Resetting limiters for visitor", v.stringNoLock())
+	var messagesLimiter, bandwidthLimiter util.Limiter
+	var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
+	if v.user != nil && v.user.Tier != nil {
+		requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
+		messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
+		emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
+		bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
+		accountLimiter = nil // A logged-in user cannot create an account
+	} else {
+		requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
+		messagesLimiter = nil // Message limit is governed by the requestLimiter
+		emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
+		bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
+		accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
+	}
+	v.requestLimiter = requestLimiter
+	v.messagesLimiter = messagesLimiter
+	v.emailsLimiter = emailsLimiter
+	v.bandwidthLimiter = bandwidthLimiter
+	v.accountLimiter = accountLimiter
 }
 }
 
 
 // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
 // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,

+ 11 - 8
user/manager.go

@@ -110,26 +110,26 @@ const (
 	`
 	`
 
 
 	selectUserByIDQuery = `
 	selectUserByIDQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		FROM user u
 		LEFT JOIN tier t on t.id = u.tier_id
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE u.id = ?		
 		WHERE u.id = ?		
 	`
 	`
 	selectUserByNameQuery = `
 	selectUserByNameQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		FROM user u
 		LEFT JOIN tier t on t.id = u.tier_id
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE user = ?
 		WHERE user = ?
 	`
 	`
 	selectUserByTokenQuery = `
 	selectUserByTokenQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		FROM user u
 		JOIN user_token t on u.id = t.user_id
 		JOIN user_token t on u.id = t.user_id
 		LEFT JOIN tier t on t.id = u.tier_id
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE t.token = ? AND t.expires >= ?
 		WHERE t.token = ? AND t.expires >= ?
 	`
 	`
 	selectUserByStripeCustomerIDQuery = `
 	selectUserByStripeCustomerIDQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		FROM user u
 		LEFT JOIN tier t on t.id = u.tier_id
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE u.stripe_customer_id = ?
 		WHERE u.stripe_customer_id = ?
@@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
 func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	defer rows.Close()
 	defer rows.Close()
 	var id, username, hash, role, prefs, syncTopic string
 	var id, username, hash, role, prefs, syncTopic string
-	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
+	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
 	var messages, emails int64
 	var messages, emails int64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
 	if !rows.Next() {
 	if !rows.Next() {
 		return nil, ErrUserNotFound
 		return nil, ErrUserNotFound
 	}
 	}
-	if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
+	if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
 		return nil, err
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 	} else if err := rows.Err(); err != nil {
 		return nil, err
 		return nil, err
@@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	if tierCode.Valid {
 	if tierCode.Valid {
 		// See readTier() when this is changed!
 		// See readTier() when this is changed!
 		user.Tier = &Tier{
 		user.Tier = &Tier{
+			ID:                       tierID.String,
 			Code:                     tierCode.String,
 			Code:                     tierCode.String,
 			Name:                     tierName.String,
 			Name:                     tierName.String,
 			MessagesLimit:            messagesLimit.Int64,
 			MessagesLimit:            messagesLimit.Int64,
@@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission {
 
 
 // CreateTier creates a new tier in the database
 // CreateTier creates a new tier in the database
 func (a *Manager) CreateTier(tier *Tier) error {
 func (a *Manager) CreateTier(tier *Tier) error {
-	tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
-	if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
+	if tier.ID == "" {
+		tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
+	}
+	if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
 		return err
 		return err
 	}
 	}
 	return nil
 	return nil

+ 9 - 0
user/types.go

@@ -23,6 +23,15 @@ type User struct {
 	Deleted   bool
 	Deleted   bool
 }
 }
 
 
+// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
+// or if the user itself is nil.
+func (u *User) TierID() string {
+	if u == nil || u.Tier == nil {
+		return ""
+	}
+	return u.Tier.ID
+}
+
 // Auther is an interface for authentication and authorization
 // Auther is an interface for authentication and authorization
 type Auther interface {
 type Auther interface {
 	// Authenticate checks username and password and returns a user if correct. The method
 	// Authenticate checks username and password and returns a user if correct. The method