Bladeren bron

No more v.user races

binwiederhier 3 jaren geleden
bovenliggende
commit
92d563371c
5 gewijzigde bestanden met toevoegingen van 87 en 77 verwijderingen
  1. 0 1
      server/server.go
  2. 61 53
      server/server_account.go
  3. 2 5
      server/server_middleware.go
  4. 22 18
      server/server_payments.go
  5. 2 0
      user/manager.go

+ 0 - 1
server/server.go

@@ -39,7 +39,6 @@ import (
 - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...)
 - HIGH Stripe payment methods
 - MEDIUM: Test new token endpoints & never-expiring token
-- MEDIUM: Races with v.user (see publishSyncEventAsync test)
 - MEDIUM: Test that anonymous user and user without tier are the same visitor
 - MEDIUM: Make sure account endpoints make sense for admins
 - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)

+ 61 - 53
server/server_account.go

@@ -19,11 +19,12 @@ const (
 )
 
 func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	admin := v.user != nil && v.user.Role == user.RoleAdmin
+	u := v.User()
+	admin := u != nil && u.Role == user.RoleAdmin
 	if !admin {
 		if !s.config.EnableSignup {
 			return errHTTPBadRequestSignupNotEnabled
-		} else if v.user != nil {
+		} else if u != nil {
 			return errHTTPUnauthorized // Cannot create account from user context
 		}
 		if !v.AccountCreationAllowed() {
@@ -150,20 +151,21 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
 	} else if req.Password == "" {
 		return errHTTPBadRequest
 	}
-	if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil {
+	u := v.User()
+	if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
 		return errHTTPBadRequestIncorrectPasswordConfirmation
 	}
-	if v.user.Billing.StripeSubscriptionID != "" {
-		log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
-		if _, err := s.stripe.CancelSubscription(v.user.Billing.StripeSubscriptionID); err != nil {
+	if u.Billing.StripeSubscriptionID != "" {
+		log.Info("%s Canceling billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
+		if _, err := s.stripe.CancelSubscription(u.Billing.StripeSubscriptionID); err != nil {
 			return err
 		}
 	}
-	if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), v.user, 0); err != nil {
+	if err := s.maybeRemoveMessagesAndExcessReservations(logHTTPPrefix(v, r), u, 0); err != nil {
 		return err
 	}
-	log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), v.user.Name)
-	if err := s.userManager.MarkUserRemoved(v.user); err != nil {
+	log.Info("%s Marking user %s as deleted", logHTTPPrefix(v, r), u.Name)
+	if err := s.userManager.MarkUserRemoved(u); err != nil {
 		return err
 	}
 	return s.writeJSON(w, newSuccessResponse())
@@ -176,10 +178,11 @@ func (s *Server) handleAccountPasswordChange(w http.ResponseWriter, r *http.Requ
 	} else if req.Password == "" || req.NewPassword == "" {
 		return errHTTPBadRequest
 	}
-	if _, err := s.userManager.Authenticate(v.user.Name, req.Password); err != nil {
+	u := v.User()
+	if _, err := s.userManager.Authenticate(u.Name, req.Password); err != nil {
 		return errHTTPBadRequestIncorrectPasswordConfirmation
 	}
-	if err := s.userManager.ChangePassword(v.user.Name, req.NewPassword); err != nil {
+	if err := s.userManager.ChangePassword(u.Name, req.NewPassword); err != nil {
 		return err
 	}
 	return s.writeJSON(w, newSuccessResponse())
@@ -267,10 +270,11 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
 	if err != nil {
 		return err
 	}
-	if v.user.Prefs == nil {
-		v.user.Prefs = &user.Prefs{}
+	u := v.User()
+	if u.Prefs == nil {
+		u.Prefs = &user.Prefs{}
 	}
-	prefs := v.user.Prefs
+	prefs := u.Prefs
 	if newPrefs.Language != nil {
 		prefs.Language = newPrefs.Language
 	}
@@ -288,7 +292,7 @@ func (s *Server) handleAccountSettingsChange(w http.ResponseWriter, r *http.Requ
 			prefs.Notification.MinPriority = newPrefs.Notification.MinPriority
 		}
 	}
-	if err := s.userManager.ChangeSettings(v.user); err != nil {
+	if err := s.userManager.ChangeSettings(u); err != nil {
 		return err
 	}
 	return s.writeJSON(w, newSuccessResponse())
@@ -299,11 +303,12 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
 	if err != nil {
 		return err
 	}
-	if v.user.Prefs == nil {
-		v.user.Prefs = &user.Prefs{}
+	u := v.User()
+	if u.Prefs == nil {
+		u.Prefs = &user.Prefs{}
 	}
 	newSubscription.ID = "" // Client cannot set ID
-	for _, subscription := range v.user.Prefs.Subscriptions {
+	for _, subscription := range u.Prefs.Subscriptions {
 		if newSubscription.BaseURL == subscription.BaseURL && newSubscription.Topic == subscription.Topic {
 			newSubscription = subscription
 			break
@@ -311,8 +316,8 @@ func (s *Server) handleAccountSubscriptionAdd(w http.ResponseWriter, r *http.Req
 	}
 	if newSubscription.ID == "" {
 		newSubscription.ID = util.RandomStringPrefix(subscriptionIDPrefix, subscriptionIDLength)
-		v.user.Prefs.Subscriptions = append(v.user.Prefs.Subscriptions, newSubscription)
-		if err := s.userManager.ChangeSettings(v.user); err != nil {
+		u.Prefs.Subscriptions = append(u.Prefs.Subscriptions, newSubscription)
+		if err := s.userManager.ChangeSettings(u); err != nil {
 			return err
 		}
 	}
@@ -329,11 +334,12 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
 	if err != nil {
 		return err
 	}
-	if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil {
+	u := v.User()
+	if u.Prefs == nil || u.Prefs.Subscriptions == nil {
 		return errHTTPNotFound
 	}
 	var subscription *user.Subscription
-	for _, sub := range v.user.Prefs.Subscriptions {
+	for _, sub := range u.Prefs.Subscriptions {
 		if sub.ID == subscriptionID {
 			sub.DisplayName = updatedSubscription.DisplayName
 			subscription = sub
@@ -343,7 +349,7 @@ func (s *Server) handleAccountSubscriptionChange(w http.ResponseWriter, r *http.
 	if subscription == nil {
 		return errHTTPNotFound
 	}
-	if err := s.userManager.ChangeSettings(v.user); err != nil {
+	if err := s.userManager.ChangeSettings(u); err != nil {
 		return err
 	}
 	return s.writeJSON(w, subscription)
@@ -355,18 +361,19 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
 		return errHTTPInternalErrorInvalidPath
 	}
 	subscriptionID := matches[1]
-	if v.user.Prefs == nil || v.user.Prefs.Subscriptions == nil {
+	u := v.User()
+	if u.Prefs == nil || u.Prefs.Subscriptions == nil {
 		return nil
 	}
 	newSubscriptions := make([]*user.Subscription, 0)
-	for _, subscription := range v.user.Prefs.Subscriptions {
+	for _, subscription := range u.Prefs.Subscriptions {
 		if subscription.ID != subscriptionID {
 			newSubscriptions = append(newSubscriptions, subscription)
 		}
 	}
-	if len(newSubscriptions) < len(v.user.Prefs.Subscriptions) {
-		v.user.Prefs.Subscriptions = newSubscriptions
-		if err := s.userManager.ChangeSettings(v.user); err != nil {
+	if len(newSubscriptions) < len(u.Prefs.Subscriptions) {
+		u.Prefs.Subscriptions = newSubscriptions
+		if err := s.userManager.ChangeSettings(u); err != nil {
 			return err
 		}
 	}
@@ -374,7 +381,8 @@ func (s *Server) handleAccountSubscriptionDelete(w http.ResponseWriter, r *http.
 }
 
 func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	if v.user != nil && v.user.Role == user.RoleAdmin {
+	u := v.User()
+	if u != nil && u.Role == user.RoleAdmin {
 		return errHTTPBadRequestMakesNoSenseForAdmin
 	}
 	req, err := readJSONWithLimit[apiAccountReservationRequest](r.Body, jsonBodyBytesLimit, false)
@@ -388,27 +396,27 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
 	if err != nil {
 		return errHTTPBadRequestPermissionInvalid
 	}
-	if v.user.Tier == nil {
+	if u.Tier == nil {
 		return errHTTPUnauthorized
 	}
 	// CHeck if we are allowed to reserve this topic
-	if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil {
+	if err := s.userManager.CheckAllowAccess(u.Name, req.Topic); err != nil {
 		return errHTTPConflictTopicReserved
 	}
-	hasReservation, err := s.userManager.HasReservation(v.user.Name, req.Topic)
+	hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
 	if err != nil {
 		return err
 	}
 	if !hasReservation {
-		reservations, err := s.userManager.ReservationsCount(v.user.Name)
+		reservations, err := s.userManager.ReservationsCount(u.Name)
 		if err != nil {
 			return err
-		} else if reservations >= v.user.Tier.ReservationLimit {
+		} else if reservations >= u.Tier.ReservationLimit {
 			return errHTTPTooManyRequestsLimitReservations
 		}
 	}
 	// Actually add the reservation
-	if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
+	if err := s.userManager.AddReservation(u.Name, req.Topic, everyone); err != nil {
 		return err
 	}
 	// Kill existing subscribers
@@ -416,7 +424,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
 	if err != nil {
 		return err
 	}
-	t.CancelSubscribers(v.user.ID)
+	t.CancelSubscribers(u.ID)
 	return s.writeJSON(w, newSuccessResponse())
 }
 
@@ -429,13 +437,14 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
 	if !topicRegex.MatchString(topic) {
 		return errHTTPBadRequestTopicInvalid
 	}
-	authorized, err := s.userManager.HasReservation(v.user.Name, topic)
+	u := v.User()
+	authorized, err := s.userManager.HasReservation(u.Name, topic)
 	if err != nil {
 		return err
 	} else if !authorized {
 		return errHTTPUnauthorized
 	}
-	if err := s.userManager.RemoveReservations(v.user.Name, topic); err != nil {
+	if err := s.userManager.RemoveReservations(u.Name, topic); err != nil {
 		return err
 	}
 	return s.writeJSON(w, newSuccessResponse())
@@ -465,12 +474,23 @@ func (s *Server) maybeRemoveMessagesAndExcessReservations(logPrefix string, u *u
 	return nil
 }
 
+// publishSyncEventAsync kicks of a Go routine to publish a sync message to the user's sync topic
+func (s *Server) publishSyncEventAsync(v *visitor) {
+	go func() {
+		if err := s.publishSyncEvent(v); err != nil {
+			log.Trace("%s Error publishing to user's sync topic: %s", v.String(), err.Error())
+		}
+	}()
+}
+
+// publishSyncEvent publishes a sync message to the user's sync topic
 func (s *Server) publishSyncEvent(v *visitor) error {
-	if v.user == nil || v.user.SyncTopic == "" {
+	u := v.User()
+	if u == nil || u.SyncTopic == "" {
 		return nil
 	}
-	log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
-	syncTopic, err := s.topicFromID(v.user.SyncTopic)
+	log.Trace("Publishing sync event to user %s's sync topic %s", u.Name, u.SyncTopic)
+	syncTopic, err := s.topicFromID(u.SyncTopic)
 	if err != nil {
 		return err
 	}
@@ -484,15 +504,3 @@ func (s *Server) publishSyncEvent(v *visitor) error {
 	}
 	return nil
 }
-
-func (s *Server) publishSyncEventAsync(v *visitor) {
-	go func() {
-		u := v.User()
-		if u == nil || u.SyncTopic == "" {
-			return
-		}
-		if err := s.publishSyncEvent(v); err != nil {
-			log.Trace("Error publishing to user %s's sync topic %s: %s", u.Name, u.SyncTopic, err.Error())
-		}
-	}()
-}

+ 2 - 5
server/server_middleware.go

@@ -24,7 +24,7 @@ func (s *Server) ensureUserManager(next handleFunc) handleFunc {
 
 func (s *Server) ensureUser(next handleFunc) handleFunc {
 	return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if v.user == nil {
+		if v.User() == nil {
 			return errHTTPUnauthorized
 		}
 		return next(w, r, v)
@@ -42,7 +42,7 @@ func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
 
 func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
 	return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if v.user.Billing.StripeCustomerID == "" {
+		if v.User().Billing.StripeCustomerID == "" {
 			return errHTTPBadRequestNotAPaidUser
 		}
 		return next(w, r, v)
@@ -51,9 +51,6 @@ func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
 
 func (s *Server) withAccountSync(next handleFunc) handleFunc {
 	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if v.user == nil {
-			return next(w, r, v)
-		}
 		err := next(w, r, v)
 		if err == nil {
 			s.publishSyncEventAsync(v)

+ 22 - 18
server/server_payments.go

@@ -54,7 +54,7 @@ var (
 )
 
 // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
-// in the UI. Note that this endpoint does NOT have a user context (no v.user!).
+// in the UI. Note that this endpoint does NOT have a user context (no u!).
 func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
 	tiers, err := s.userManager.Tiers()
 	if err != nil {
@@ -107,7 +107,8 @@ func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _
 // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
 // will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
 func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	if v.user.Billing.StripeSubscriptionID != "" {
+	u := v.User()
+	if u.Billing.StripeSubscriptionID != "" {
 		return errHTTPBadRequestBillingSubscriptionExists
 	}
 	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
@@ -122,9 +123,9 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 	}
 	log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
 	var stripeCustomerID *string
-	if v.user.Billing.StripeCustomerID != "" {
-		stripeCustomerID = &v.user.Billing.StripeCustomerID
-		stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
+	if u.Billing.StripeCustomerID != "" {
+		stripeCustomerID = &u.Billing.StripeCustomerID
+		stripeCustomer, err := s.stripe.GetCustomer(u.Billing.StripeCustomerID)
 		if err != nil {
 			return err
 		} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
@@ -134,7 +135,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 	successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
 	params := &stripe.CheckoutSessionParams{
 		Customer:            stripeCustomerID, // A user may have previously deleted their subscription
-		ClientReferenceID:   &v.user.ID,
+		ClientReferenceID:   &u.ID,
 		SuccessURL:          &successURL,
 		Mode:                stripe.String(string(stripe.CheckoutSessionModeSubscription)),
 		AllowPromotionCodes: stripe.Bool(true),
@@ -146,7 +147,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 		},
 		Params: stripe.Params{
 			Metadata: map[string]string{
-				"user_id": v.user.ID,
+				"user_id": u.ID,
 			},
 		},
 	}
@@ -164,7 +165,7 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
 // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
 // and only time we can map the local username with the Stripe customer ID.
 func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	// We don't have a v.user in this endpoint, only a userManager!
+	// We don't have v.User() in this endpoint, only a userManager!
 	matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
 	if len(matches) != 2 {
 		return errHTTPInternalErrorInvalidPath
@@ -212,7 +213,8 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
 // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
 // a user's tier accordingly. This endpoint only works if there is an existing subscription.
 func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	if v.user.Billing.StripeSubscriptionID == "" {
+	u := v.User()
+	if u.Billing.StripeSubscriptionID == "" {
 		return errNoBillingSubscription
 	}
 	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit, false)
@@ -223,8 +225,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
 	if err != nil {
 		return err
 	}
-	log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
-	sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
+	log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
+	sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
 	if err != nil {
 		return err
 	}
@@ -248,12 +250,13 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
 // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
 // and cancelling the Stripe subscription entirely
 func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
-	if v.user.Billing.StripeSubscriptionID != "" {
+	u := v.User()
+	log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
+	if u.Billing.StripeSubscriptionID != "" {
 		params := &stripe.SubscriptionParams{
 			CancelAtPeriodEnd: stripe.Bool(true),
 		}
-		_, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
+		_, err := s.stripe.UpdateSubscription(u.Billing.StripeSubscriptionID, params)
 		if err != nil {
 			return err
 		}
@@ -264,12 +267,13 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
 // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
 // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
 func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	if v.user.Billing.StripeCustomerID == "" {
+	u := v.User()
+	if u.Billing.StripeCustomerID == "" {
 		return errHTTPBadRequestNotAPaidUser
 	}
 	log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
 	params := &stripe.BillingPortalSessionParams{
-		Customer:  stripe.String(v.user.Billing.StripeCustomerID),
+		Customer:  stripe.String(u.Billing.StripeCustomerID),
 		ReturnURL: stripe.String(s.config.BaseURL),
 	}
 	ps, err := s.stripe.NewPortalSession(params)
@@ -284,8 +288,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
 
 // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
 // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
-// visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
-func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
+// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
+func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
 	stripeSignature := r.Header.Get("Stripe-Signature")
 	if stripeSignature == "" {
 		return errHTTPBadRequestBillingRequestInvalid

+ 2 - 0
user/manager.go

@@ -30,6 +30,7 @@ const (
 	tokenMaxCount                   = 10 // Only keep this many tokens in the table per user
 )
 
+// Default constants that may be overridden by configs
 const (
 	DefaultUserStatsQueueWriterInterval = 33 * time.Second
 	DefaultUserPasswordBcryptCost       = 10
@@ -1195,6 +1196,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
 	}, nil
 }
 
+// Close closes the underlying database
 func (a *Manager) Close() error {
 	return a.db.Close()
 }