|
|
@@ -2,7 +2,6 @@ package server
|
|
|
|
|
|
import (
|
|
|
"bytes"
|
|
|
- "encoding/json"
|
|
|
"errors"
|
|
|
"fmt"
|
|
|
"github.com/stripe/stripe-go/v74"
|
|
|
@@ -121,7 +120,13 @@ func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r
|
|
|
} else if tier.StripePriceID == "" {
|
|
|
return errNotAPaidTier
|
|
|
}
|
|
|
- log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Fields(map[string]any{
|
|
|
+ "tier": tier,
|
|
|
+ "stripe_price_id": tier.StripePriceID,
|
|
|
+ }).
|
|
|
+ Info("Creating Stripe checkout flow")
|
|
|
var stripeCustomerID *string
|
|
|
if u.Billing.StripeCustomerID != "" {
|
|
|
stripeCustomerID = &u.Billing.StripeCustomerID
|
|
|
@@ -190,6 +195,18 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|
|
return err
|
|
|
}
|
|
|
v.SetUser(u)
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Fields(map[string]any{
|
|
|
+ "tier_id": tier.ID,
|
|
|
+ "tier_name": tier.Name,
|
|
|
+ "stripe_price_id": tier.StripePriceID,
|
|
|
+ "stripe_customer_id": sess.Customer.ID,
|
|
|
+ "stripe_subscription_id": sub.ID,
|
|
|
+ "stripe_subscription_status": string(sub.Status),
|
|
|
+ "stripe_subscription_paid_until": sub.CurrentPeriodEnd,
|
|
|
+ }).
|
|
|
+ Info("Stripe checkout flow succeeded, updating user tier and subscription")
|
|
|
customerParams := &stripe.CustomerParams{
|
|
|
Params: stripe.Params{
|
|
|
Metadata: map[string]string{
|
|
|
@@ -201,7 +218,7 @@ func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWr
|
|
|
if _, err := s.stripe.UpdateCustomer(sess.Customer.ID, customerParams); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
|
|
+ if err := s.updateSubscriptionAndTier(r, v, u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
|
|
|
@@ -223,7 +240,15 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, u.Billing.StripeSubscriptionID)
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Fields(map[string]any{
|
|
|
+ "new_tier_id": tier.ID,
|
|
|
+ "new_tier_name": tier.Name,
|
|
|
+ "new_tier_stripe_price_id": tier.StripePriceID,
|
|
|
+ // Other stripe_* fields filled by visitor context
|
|
|
+ }).
|
|
|
+ Info("Changing Stripe subscription and billing tier to %s/%s (price %s)", tier.ID, tier.Name, tier.StripePriceID)
|
|
|
sub, err := s.stripe.GetSubscription(u.Billing.StripeSubscriptionID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
@@ -250,8 +275,8 @@ func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r
|
|
|
// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
|
|
|
// and cancelling the Stripe subscription entirely
|
|
|
func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
|
+ logvr(v, r).Tag(tagPay).Info("Deleting Stripe subscription")
|
|
|
u := v.User()
|
|
|
- log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), u.Billing.StripeSubscriptionID)
|
|
|
if u.Billing.StripeSubscriptionID != "" {
|
|
|
params := &stripe.SubscriptionParams{
|
|
|
CancelAtPeriodEnd: stripe.Bool(true),
|
|
|
@@ -267,11 +292,11 @@ func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r
|
|
|
// handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
|
|
|
// redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
|
|
|
func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
|
+ logvr(v, r).Tag(tagPay).Info("Creating Stripe billing portal session")
|
|
|
u := v.User()
|
|
|
if u.Billing.StripeCustomerID == "" {
|
|
|
return errHTTPBadRequestNotAPaidUser
|
|
|
}
|
|
|
- log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
|
|
|
params := &stripe.BillingPortalSessionParams{
|
|
|
Customer: stripe.String(u.Billing.StripeCustomerID),
|
|
|
ReturnURL: stripe.String(s.config.BaseURL),
|
|
|
@@ -289,7 +314,7 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
|
|
|
// handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
|
|
|
// with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
|
|
|
// visitor (v) in this endpoint is the Stripe API, so we don't have u available.
|
|
|
-func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, _ *visitor) error {
|
|
|
+func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Request, v *visitor) error {
|
|
|
stripeSignature := r.Header.Get("Stripe-Signature")
|
|
|
if stripeSignature == "" {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
@@ -308,74 +333,105 @@ func (s *Server) handleAccountBillingWebhook(_ http.ResponseWriter, r *http.Requ
|
|
|
}
|
|
|
switch event.Type {
|
|
|
case "customer.subscription.updated":
|
|
|
- return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
|
|
|
+ return s.handleAccountBillingWebhookSubscriptionUpdated(r, v, event)
|
|
|
case "customer.subscription.deleted":
|
|
|
- return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
|
|
|
+ return s.handleAccountBillingWebhookSubscriptionDeleted(r, v, event)
|
|
|
default:
|
|
|
- log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Field("stripe_webhook_type", event.Type).
|
|
|
+ Warn("Unhandled Stripe webhook event %s received", event.Type)
|
|
|
return nil
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
|
|
|
- ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
|
|
|
+func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(r *http.Request, v *visitor, event stripe.Event) error {
|
|
|
+ ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
|
|
if err != nil {
|
|
|
return err
|
|
|
} else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
}
|
|
|
subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
|
|
|
- log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Fields(map[string]any{
|
|
|
+ "stripe_webhook_type": event.Type,
|
|
|
+ "stripe_customer_id": ev.Customer,
|
|
|
+ "stripe_subscription_id": ev.ID,
|
|
|
+ "stripe_subscription_status": ev.Status,
|
|
|
+ "stripe_subscription_paid_until": ev.CurrentPeriodEnd,
|
|
|
+ "stripe_subscription_cancel_at": ev.CancelAt,
|
|
|
+ "stripe_price_id": priceID,
|
|
|
+ }).
|
|
|
+ Info("Updating subscription to status %s, with price %s", ev.Status, priceID)
|
|
|
userFn := func() (*user.User, error) {
|
|
|
return s.userManager.UserByStripeCustomer(ev.Customer)
|
|
|
}
|
|
|
+ // We retry the user retrieval function, because during the Stripe checkout, there a race between the browser
|
|
|
+ // checkout success redirect (see handleAccountBillingSubscriptionCreateSuccess), and this webhook. The checkout
|
|
|
+ // success call is the one that updates the user with the Stripe customer ID.
|
|
|
u, err := util.Retry[user.User](userFn, retryUserDelays...)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
+ v.SetUser(u)
|
|
|
tier, err := s.userManager.TierByStripePrice(priceID)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
|
|
+ if err := s.updateSubscriptionAndTier(r, v, u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
|
|
|
- ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
|
|
|
+func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(r *http.Request, v *visitor, event stripe.Event) error {
|
|
|
+ ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event.Data.Raw)))
|
|
|
if err != nil {
|
|
|
return err
|
|
|
} else if ev.Customer == "" {
|
|
|
return errHTTPBadRequestBillingRequestInvalid
|
|
|
}
|
|
|
- log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
|
|
|
u, err := s.userManager.UserByStripeCustomer(ev.Customer)
|
|
|
if err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
|
|
+ v.SetUser(u)
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Field("stripe_webhook_type", event.Type).
|
|
|
+ Info("Subscription deleted, downgrading to unpaid tier")
|
|
|
+ if err := s.updateSubscriptionAndTier(r, v, u, nil, ev.Customer, "", "", 0, 0); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
s.publishSyncEventAsync(s.visitor(netip.IPv4Unspecified(), u))
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
|
|
+func (s *Server) updateSubscriptionAndTier(r *http.Request, v *visitor, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
|
|
|
reservationsLimit := visitorDefaultReservationsLimit
|
|
|
if tier != nil {
|
|
|
reservationsLimit = tier.ReservationLimit
|
|
|
}
|
|
|
- if err := s.maybeRemoveMessagesAndExcessReservations(logPrefix, u, reservationsLimit); err != nil {
|
|
|
+ if err := s.maybeRemoveMessagesAndExcessReservations(r, v, u, reservationsLimit); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- if tier == nil {
|
|
|
+ if tier == nil && u.Tier != nil {
|
|
|
+ logvr(v, r).Tag(tagPay).Info("Resetting tier for user %s", u.Name)
|
|
|
if err := s.userManager.ResetTier(u.Name); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
- } else {
|
|
|
+ } else if tier != nil && u.TierID() != tier.ID {
|
|
|
+ logvr(v, r).
|
|
|
+ Tag(tagPay).
|
|
|
+ Fields(map[string]any{
|
|
|
+ "new_tier_id": tier.ID,
|
|
|
+ "new_tier_name": tier.Name,
|
|
|
+ "new_tier_stripe_price_id": tier.StripePriceID,
|
|
|
+ }).
|
|
|
+ Info("Changing tier to tier %s (%s) for user %s", tier.ID, tier.Name, u.Name)
|
|
|
if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
|
|
|
return err
|
|
|
}
|