Просмотр исходного кода

Payment checkout test, rate limit resetting on tier change; failing

binwiederhier 3 лет назад
Родитель
Сommit
593e0748a8
8 измененных файлов с 257 добавлено и 42 удалено
  1. 6 6
      server/config.go
  2. 3 5
      server/server.go
  3. 1 1
      server/server_account.go
  4. 72 0
      server/server_account_test.go
  5. 105 0
      server/server_payments_test.go
  6. 50 22
      server/visitor.go
  7. 11 8
      user/manager.go
  8. 9 0
      user/types.go

+ 6 - 6
server/config.go

@@ -46,8 +46,8 @@ const (
 	DefaultVisitorRequestLimitReplenish         = 5 * time.Second
 	DefaultVisitorEmailLimitBurst               = 16
 	DefaultVisitorEmailLimitReplenish           = time.Hour
-	DefaultVisitorAccountCreateLimitBurst       = 3
-	DefaultVisitorAccountCreateLimitReplenish   = 24 * time.Hour
+	DefaultVisitorAccountCreationLimitBurst     = 3
+	DefaultVisitorAccountCreationLimitReplenish = 24 * time.Hour
 	DefaultVisitorAttachmentTotalSizeLimit      = 100 * 1024 * 1024 // 100 MB
 	DefaultVisitorAttachmentDailyBandwidthLimit = 500 * 1024 * 1024 // 500 MB
 )
@@ -107,8 +107,8 @@ type Config struct {
 	VisitorRequestExemptIPAddrs          []netip.Prefix
 	VisitorEmailLimitBurst               int
 	VisitorEmailLimitReplenish           time.Duration
-	VisitorAccountCreateLimitBurst       int
-	VisitorAccountCreateLimitReplenish   time.Duration
+	VisitorAccountCreationLimitBurst     int
+	VisitorAccountCreationLimitReplenish time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	BehindProxy                          bool
 	StripeSecretKey                      string
@@ -173,8 +173,8 @@ func NewConfig() *Config {
 		VisitorRequestExemptIPAddrs:          make([]netip.Prefix, 0),
 		VisitorEmailLimitBurst:               DefaultVisitorEmailLimitBurst,
 		VisitorEmailLimitReplenish:           DefaultVisitorEmailLimitReplenish,
-		VisitorAccountCreateLimitBurst:       DefaultVisitorAccountCreateLimitBurst,
-		VisitorAccountCreateLimitReplenish:   DefaultVisitorAccountCreateLimitReplenish,
+		VisitorAccountCreationLimitBurst:     DefaultVisitorAccountCreationLimitBurst,
+		VisitorAccountCreationLimitReplenish: DefaultVisitorAccountCreationLimitReplenish,
 		VisitorStatsResetTime:                DefaultVisitorStatsResetTime,
 		BehindProxy:                          false,
 		StripeSecretKey:                      "",

+ 3 - 5
server/server.go

@@ -40,6 +40,8 @@ TODO
 
 - HIGH Rate limiting: dailyLimitToRate is wrong? + TESTS
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
+- HIGH Rate limiting: Delete visitor when tier is changed to refresh rate limiters
+- HIGH Rate limiting: When ResetStats() is run, reset messagesLimiter (and others)?
 - MEDIUM: Races with v.user (see publishSyncEventAsync test)
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
 - MEDIUM: Reservation (UI): Ask for confirmation when removing reservation (deadcade)
@@ -50,8 +52,6 @@ TODO
 
 Limits & rate limiting:
 	users without tier: should the stats be persisted? are they meaningful? -> test that the visitor is based on the IP address!
-	when ResetStats() is run, reset messagesLimiter (and others)?
-	Delete visitor when tier is changed to refresh rate limiters
 
 Make sure account endpoints make sense for admins
 
@@ -1602,9 +1602,7 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
 	} else {
 		v = s.visitorFromIP(ip)
 	}
-	v.mu.Lock()
-	v.user = u
-	v.mu.Unlock()
+	v.SetUser(u)  // Update visitor user with latest from database!
 	return v, err // Always return visitor, even when error occurs!
 }
 

+ 1 - 1
server/server_account.go

@@ -31,7 +31,7 @@ func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *
 	if existingUser, _ := s.userManager.User(newAccount.Username); existingUser != nil {
 		return errHTTPConflictUserExists
 	}
-	if v.accountLimiter != nil && !v.accountLimiter.Allow() {
+	if err := v.AccountCreationAllowed(); err != nil {
 		return errHTTPTooManyRequestsLimitAccountCreation
 	}
 	if err := s.userManager.AddUser(newAccount.Username, newAccount.Password, user.RoleUser); err != nil { // TODO this should return a User

+ 72 - 0
server/server_account_test.go

@@ -6,6 +6,7 @@ import (
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
 	"io"
+	"strings"
 	"testing"
 	"time"
 )
@@ -91,6 +92,20 @@ func TestAccount_Signup_Disabled(t *testing.T) {
 	require.Equal(t, 40022, toHTTPError(t, rr.Body.String()).Code)
 }
 
+func TestAccount_Signup_Rate_Limit(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	conf.EnableSignup = true
+	s := newTestServer(t, conf)
+
+	for i := 0; i < 3; i++ {
+		rr := request(t, s, "POST", "/v1/account", fmt.Sprintf(`{"username":"phil%d", "password":"mypass"}`, i), nil)
+		require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
+	}
+	rr := request(t, s, "POST", "/v1/account", `{"username":"notallowed", "password":"mypass"}`, nil)
+	require.Equal(t, 429, rr.Code)
+	require.Equal(t, 42906, toHTTPError(t, rr.Body.String()).Code)
+}
+
 func TestAccount_Get_Anonymous(t *testing.T) {
 	conf := newTestConfigWithAuthFile(t)
 	conf.VisitorRequestLimitReplenish = 86 * time.Second
@@ -567,3 +582,60 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
 	s.topics["mytopic"].CancelSubscribers("<invalid>")
 	<-userCh
 }
+
+func TestAccount_Tier_Create(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	s := newTestServer(t, conf)
+
+	// Create tier and user
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		Code:                     "pro",
+		Name:                     "Pro",
+		MessagesLimit:            123,
+		MessagesExpiryDuration:   86400 * time.Second,
+		EmailsLimit:              32,
+		ReservationsLimit:        2,
+		AttachmentFileSizeLimit:  1231231,
+		AttachmentTotalSizeLimit: 123123,
+		AttachmentExpiryDuration: 10800 * time.Second,
+		AttachmentBandwidthLimit: 21474836480,
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
+	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
+
+	ti, err := s.userManager.Tier("pro")
+	require.Nil(t, err)
+
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+
+	// These are populated by different SQL queries
+	require.Equal(t, ti, u.Tier)
+
+	// Fields
+	require.True(t, strings.HasPrefix(ti.ID, "ti_"))
+	require.Equal(t, "pro", ti.Code)
+	require.Equal(t, "Pro", ti.Name)
+	require.Equal(t, int64(123), ti.MessagesLimit)
+	require.Equal(t, 86400*time.Second, ti.MessagesExpiryDuration)
+	require.Equal(t, int64(32), ti.EmailsLimit)
+	require.Equal(t, int64(2), ti.ReservationsLimit)
+	require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
+	require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
+	require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
+	require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
+}
+
+func TestAccount_Tier_Create_With_ID(t *testing.T) {
+	conf := newTestConfigWithAuthFile(t)
+	s := newTestServer(t, conf)
+
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:   "ti_123",
+		Code: "pro",
+	}))
+
+	ti, err := s.userManager.Tier("pro")
+	require.Nil(t, err)
+	require.Equal(t, "ti_123", ti.ID)
+}

+ 105 - 0
server/server_payments_test.go

@@ -133,6 +133,7 @@ func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
 
 	// Create tier and user
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
 		Code:          "pro",
 		StripePriceID: "price_123",
 	}))
@@ -168,6 +169,7 @@ func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
 
 	// Create tier and user
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
 		Code:          "pro",
 		StripePriceID: "price_123",
 	}))
@@ -209,6 +211,7 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
 
 	// Create tier and user
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:            "ti_123",
 		Code:          "pro",
 		StripePriceID: "price_123",
 	}))
@@ -235,6 +238,106 @@ func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
 	require.Equal(t, 401, rr.Code)
 }
 
+func TestPayments_Checkout_Success_And_Increase_Ratelimits_Reset_Visitor(t *testing.T) {
+	// This tests a successful checkout flow (not a paying customer -> paying customer),
+	// and also tests that during the upgrade we are RESETTING THE RATE LIMITS of the existing user.
+
+	stripeMock := &testStripeAPI{}
+	defer stripeMock.AssertExpectations(t)
+
+	c := newTestConfigWithAuthFile(t)
+	c.StripeSecretKey = "secret key"
+	c.StripeWebhookKey = "webhook key"
+	c.VisitorRequestLimitBurst = 10
+	c.VisitorRequestLimitReplenish = time.Hour
+	s := newTestServer(t, c)
+	s.stripe = stripeMock
+
+	// Create a user with a Stripe subscription and 3 reservations
+	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:                     "ti_123",
+		Code:                   "starter",
+		StripePriceID:          "price_1234",
+		ReservationsLimit:      1,
+		MessagesLimit:          100,
+		MessagesExpiryDuration: time.Hour,
+	}))
+	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
+	u, err := s.userManager.User("phil")
+	require.Nil(t, err)
+
+	// Define how the mock should react
+	stripeMock.
+		On("GetSession", "SOMETOKEN").
+		Return(&stripe.CheckoutSession{
+			ClientReferenceID: u.ID, // ntfy user ID
+			Customer: &stripe.Customer{
+				ID: "acct_5555",
+			},
+			Subscription: &stripe.Subscription{
+				ID: "sub_1234",
+			},
+		}, nil)
+	stripeMock.
+		On("GetSubscription", "sub_1234").
+		Return(&stripe.Subscription{
+			ID:               "sub_1234",
+			Status:           stripe.SubscriptionStatusActive,
+			CurrentPeriodEnd: 123456789,
+			CancelAt:         0,
+			Items: &stripe.SubscriptionItemList{
+				Data: []*stripe.SubscriptionItem{
+					{
+						Price: &stripe.Price{ID: "price_1234"},
+					},
+				},
+			},
+		}, nil)
+	stripeMock.
+		On("UpdateCustomer", mock.Anything).
+		Return(&stripe.Customer{}, nil)
+
+	// Send messages until rate limit of free tier is hit
+	for i := 0; i < 10; i++ {
+		rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+			"Authorization": util.BasicAuth("phil", "phil"),
+		})
+		require.Equal(t, 200, rr.Code)
+	}
+	rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 429, rr.Code)
+
+	// Simulate Stripe success return URL call (no user context)
+	rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
+	require.Equal(t, 303, rr.Code)
+
+	// Verify that database columns were updated
+	u, err = s.userManager.User("phil")
+	require.Nil(t, err)
+	require.Equal(t, "starter", u.Tier.Code) // Not "pro"
+	require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
+	require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
+	require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
+	require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
+	require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
+
+	// FIXME FIXME This test is broken, because the rate limit logic is unclear!
+
+	// Now for the fun part: Verify that new rate limits are immediately applied
+	for i := 0; i < 100; i++ {
+		rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+			"Authorization": util.BasicAuth("phil", "phil"),
+		})
+		require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
+	}
+	rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
+		"Authorization": util.BasicAuth("phil", "phil"),
+	})
+	require.Equal(t, 429, rr.Code)
+}
+
 func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
 	// This tests incoming webhooks from Stripe to update a subscription:
 	// - All Stripe columns are updated in the user table
@@ -257,6 +360,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 
 	// Create a user with a Stripe subscription and 3 reservations
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:                       "ti_1",
 		Code:                     "starter",
 		StripePriceID:            "price_1234", // !
 		ReservationsLimit:        1,            // !
@@ -268,6 +372,7 @@ func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(
 		AttachmentBandwidthLimit: 1000000,
 	}))
 	require.Nil(t, s.userManager.CreateTier(&user.Tier{
+		ID:                       "ti_2",
 		Code:                     "pro",
 		StripePriceID:            "price_1111", // !
 		ReservationsLimit:        3,            // !

+ 50 - 22
server/visitor.go

@@ -3,6 +3,7 @@ package server
 import (
 	"errors"
 	"fmt"
+	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/user"
 	"net/netip"
 	"sync"
@@ -41,7 +42,7 @@ type visitor struct {
 	emailsLimiter       *rate.Limiter // Rate limiter for emails
 	subscriptionLimiter util.Limiter  // Fixed limiter for active subscriptions (ongoing connections)
 	bandwidthLimiter    util.Limiter  // Limiter for attachment bandwidth downloads
-	accountLimiter      *rate.Limiter // Rate limiter for account creation
+	accountLimiter      *rate.Limiter // Rate limiter for account creation, may be nil
 	firebase            time.Time     // Next allowed Firebase message
 	seen                time.Time     // Last seen time of this visitor (needed for removal of stale visitors)
 	mu                  sync.Mutex
@@ -85,26 +86,12 @@ const (
 )
 
 func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Manager, ip netip.Addr, user *user.User) *visitor {
-	var messagesLimiter, attachmentBandwidthLimiter util.Limiter
-	var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
 	var messages, emails int64
 	if user != nil {
 		messages = user.Stats.Messages
 		emails = user.Stats.Emails
-	} else {
-		accountLimiter = rate.NewLimiter(rate.Every(conf.VisitorAccountCreateLimitReplenish), conf.VisitorAccountCreateLimitBurst)
 	}
-	if user != nil && user.Tier != nil {
-		requestLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.MessagesLimit), conf.VisitorRequestLimitBurst)
-		messagesLimiter = util.NewFixedLimiter(user.Tier.MessagesLimit)
-		emailsLimiter = rate.NewLimiter(dailyLimitToRate(user.Tier.EmailsLimit), conf.VisitorEmailLimitBurst)
-		attachmentBandwidthLimiter = util.NewBytesLimiter(int(user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
-	} else {
-		requestLimiter = rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst)
-		emailsLimiter = rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst)
-		attachmentBandwidthLimiter = util.NewBytesLimiter(int(conf.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
-	}
-	return &visitor{
+	v := &visitor{
 		config:              conf,
 		messageCache:        messageCache,
 		userManager:         userManager, // May be nil
@@ -112,20 +99,26 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
 		user:                user,
 		messages:            messages,
 		emails:              emails,
-		requestLimiter:      requestLimiter,
-		messagesLimiter:     messagesLimiter, // May be nil
-		emailsLimiter:       emailsLimiter,
-		subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
-		bandwidthLimiter:    attachmentBandwidthLimiter,
-		accountLimiter:      accountLimiter, // May be nil
 		firebase:            time.Unix(0, 0),
 		seen:                time.Now(),
+		subscriptionLimiter: util.NewFixedLimiter(int64(conf.VisitorSubscriptionLimit)),
+		requestLimiter:      nil, // Set in resetLimiters
+		messagesLimiter:     nil, // Set in resetLimiters, may be nil
+		emailsLimiter:       nil, // Set in resetLimiters
+		bandwidthLimiter:    nil, // Set in resetLimiters
+		accountLimiter:      nil, // Set in resetLimiters, may be nil
 	}
+	v.resetLimiters()
+	return v
 }
 
 func (v *visitor) String() string {
 	v.mu.Lock()
 	defer v.mu.Unlock()
+	return v.stringNoLock()
+}
+
+func (v *visitor) stringNoLock() string {
 	if v.user != nil && v.user.Billing.StripeCustomerID != "" {
 		return fmt.Sprintf("%s/%s/%s", v.ip.String(), v.user.ID, v.user.Billing.StripeCustomerID)
 	} else if v.user != nil {
@@ -179,6 +172,13 @@ func (v *visitor) SubscriptionAllowed() error {
 	return nil
 }
 
+func (v *visitor) AccountCreationAllowed() error {
+	if v.accountLimiter != nil && !v.accountLimiter.Allow() {
+		return errVisitorLimitReached
+	}
+	return nil
+}
+
 func (v *visitor) RemoveSubscription() {
 	v.mu.Lock()
 	defer v.mu.Unlock()
@@ -235,7 +235,35 @@ func (v *visitor) ResetStats() {
 func (v *visitor) SetUser(u *user.User) {
 	v.mu.Lock()
 	defer v.mu.Unlock()
+	shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
 	v.user = u
+	if shouldResetLimiters {
+		v.resetLimiters()
+	}
+}
+
+func (v *visitor) resetLimiters() {
+	log.Info("%s Resetting limiters for visitor", v.stringNoLock())
+	var messagesLimiter, bandwidthLimiter util.Limiter
+	var requestLimiter, emailsLimiter, accountLimiter *rate.Limiter
+	if v.user != nil && v.user.Tier != nil {
+		requestLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.MessagesLimit), v.config.VisitorRequestLimitBurst)
+		messagesLimiter = util.NewFixedLimiter(v.user.Tier.MessagesLimit)
+		emailsLimiter = rate.NewLimiter(dailyLimitToRate(v.user.Tier.EmailsLimit), v.config.VisitorEmailLimitBurst)
+		bandwidthLimiter = util.NewBytesLimiter(int(v.user.Tier.AttachmentBandwidthLimit), 24*time.Hour)
+		accountLimiter = nil // A logged-in user cannot create an account
+	} else {
+		requestLimiter = rate.NewLimiter(rate.Every(v.config.VisitorRequestLimitReplenish), v.config.VisitorRequestLimitBurst)
+		messagesLimiter = nil // Message limit is governed by the requestLimiter
+		emailsLimiter = rate.NewLimiter(rate.Every(v.config.VisitorEmailLimitReplenish), v.config.VisitorEmailLimitBurst)
+		bandwidthLimiter = util.NewBytesLimiter(int(v.config.VisitorAttachmentDailyBandwidthLimit), 24*time.Hour)
+		accountLimiter = rate.NewLimiter(rate.Every(v.config.VisitorAccountCreationLimitReplenish), v.config.VisitorAccountCreationLimitBurst)
+	}
+	v.requestLimiter = requestLimiter
+	v.messagesLimiter = messagesLimiter
+	v.emailsLimiter = emailsLimiter
+	v.bandwidthLimiter = bandwidthLimiter
+	v.accountLimiter = accountLimiter
 }
 
 // MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,

+ 11 - 8
user/manager.go

@@ -110,26 +110,26 @@ const (
 	`
 
 	selectUserByIDQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE u.id = ?		
 	`
 	selectUserByNameQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE user = ?
 	`
 	selectUserByTokenQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		JOIN user_token t on u.id = t.user_id
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE t.token = ? AND t.expires >= ?
 	`
 	selectUserByStripeCustomerIDQuery = `
-		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
+		SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
 		FROM user u
 		LEFT JOIN tier t on t.id = u.tier_id
 		WHERE u.stripe_customer_id = ?
@@ -669,13 +669,13 @@ func (a *Manager) userByToken(token string) (*User, error) {
 func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	defer rows.Close()
 	var id, username, hash, role, prefs, syncTopic string
-	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
+	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierID, tierCode, tierName sql.NullString
 	var messages, emails int64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, attachmentBandwidthLimit, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt, deleted sql.NullInt64
 	if !rows.Next() {
 		return nil, ErrUserNotFound
 	}
-	if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
+	if err := rows.Scan(&id, &username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &deleted, &tierID, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &attachmentBandwidthLimit, &stripePriceID); err != nil {
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 		return nil, err
@@ -706,6 +706,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	if tierCode.Valid {
 		// See readTier() when this is changed!
 		user.Tier = &Tier{
+			ID:                       tierID.String,
 			Code:                     tierCode.String,
 			Name:                     tierName.String,
 			MessagesLimit:            messagesLimit.Int64,
@@ -995,8 +996,10 @@ func (a *Manager) DefaultAccess() Permission {
 
 // CreateTier creates a new tier in the database
 func (a *Manager) CreateTier(tier *Tier) error {
-	tierID := util.RandomStringPrefix(tierIDPrefix, tierIDLength)
-	if _, err := a.db.Exec(insertTierQuery, tierID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
+	if tier.ID == "" {
+		tier.ID = util.RandomStringPrefix(tierIDPrefix, tierIDLength)
+	}
+	if _, err := a.db.Exec(insertTierQuery, tier.ID, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds()), tier.AttachmentBandwidthLimit, tier.StripePriceID); err != nil {
 		return err
 	}
 	return nil

+ 9 - 0
user/types.go

@@ -23,6 +23,15 @@ type User struct {
 	Deleted   bool
 }
 
+// TierID returns the ID of the User.Tier, or an empty string if the user has no tier,
+// or if the user itself is nil.
+func (u *User) TierID() string {
+	if u == nil || u.Tier == nil {
+		return ""
+	}
+	return u.Tier.ID
+}
+
 // Auther is an interface for authentication and authorization
 type Auther interface {
 	// Authenticate checks username and password and returns a user if correct. The method