Просмотр исходного кода

publishSyncEvent, Stripe endpoint changes

binwiederhier 3 лет назад
Родитель
Сommit
83de879894

+ 11 - 11
cmd/serve.go

@@ -80,8 +80,8 @@ var flagsServe = append(
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
-	altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}),
-	altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-secret-key", Aliases: []string{"stripe_secret_key"}, EnvVars: []string{"NTFY_STRIPE_SECRET_KEY"}, Value: "", Usage: "key used for the Stripe API communication, this enables payments"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "key required to validate the authenticity of incoming webhooks from Stripe"}),
 )
 )
 
 
 var cmdServe = &cli.Command{
 var cmdServe = &cli.Command{
@@ -153,7 +153,7 @@ func execServe(c *cli.Context) error {
 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
 	visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
 	visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
 	behindProxy := c.Bool("behind-proxy")
 	behindProxy := c.Bool("behind-proxy")
-	stripeKey := c.String("stripe-key")
+	stripeSecretKey := c.String("stripe-secret-key")
 	stripeWebhookKey := c.String("stripe-webhook-key")
 	stripeWebhookKey := c.String("stripe-webhook-key")
 
 
 	// Check values
 	// Check values
@@ -191,17 +191,17 @@ func execServe(c *cli.Context) error {
 		return errors.New("if upstream-base-url is set, base-url must also be set")
 		return errors.New("if upstream-base-url is set, base-url must also be set")
 	} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
 	} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
 		return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
 		return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
-	} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") {
-		return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set")
+	} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeSecretKey != "") {
+		return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-secret-key if auth-file is not set")
 	} else if enableSignup && !enableLogin {
 	} else if enableSignup && !enableLogin {
 		return errors.New("cannot set enable-signup without also setting enable-login")
 		return errors.New("cannot set enable-signup without also setting enable-login")
-	} else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") {
-		return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set")
+	} else if stripeSecretKey != "" && (stripeWebhookKey == "" || baseURL == "") {
+		return errors.New("if stripe-secret-key is set, stripe-webhook-key and base-url must also be set")
 	}
 	}
 
 
 	webRootIsApp := webRoot == "app"
 	webRootIsApp := webRoot == "app"
 	enableWeb := webRoot != "disable"
 	enableWeb := webRoot != "disable"
-	enablePayments := stripeKey != ""
+	enablePayments := stripeSecretKey != ""
 
 
 	// Default auth permissions
 	// Default auth permissions
 	authDefault, err := user.ParsePermission(authDefaultAccess)
 	authDefault, err := user.ParsePermission(authDefaultAccess)
@@ -246,8 +246,8 @@ func execServe(c *cli.Context) error {
 	}
 	}
 
 
 	// Stripe things
 	// Stripe things
-	if stripeKey != "" {
-		stripe.Key = stripeKey
+	if stripeSecretKey != "" {
+		stripe.Key = stripeSecretKey
 	}
 	}
 
 
 	// Run server
 	// Run server
@@ -293,7 +293,7 @@ func execServe(c *cli.Context) error {
 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
 	conf.BehindProxy = behindProxy
 	conf.BehindProxy = behindProxy
-	conf.StripeKey = stripeKey
+	conf.StripeSecretKey = stripeSecretKey
 	conf.StripeWebhookKey = stripeWebhookKey
 	conf.StripeWebhookKey = stripeWebhookKey
 	conf.EnableWeb = enableWeb
 	conf.EnableWeb = enableWeb
 	conf.EnableSignup = enableSignup
 	conf.EnableSignup = enableSignup

+ 1 - 1
server/config.go

@@ -110,7 +110,7 @@ type Config struct {
 	VisitorAccountCreateLimitReplenish   time.Duration
 	VisitorAccountCreateLimitReplenish   time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	BehindProxy                          bool
 	BehindProxy                          bool
-	StripeKey                            string
+	StripeSecretKey                      string
 	StripeWebhookKey                     string
 	StripeWebhookKey                     string
 	EnableWeb                            bool
 	EnableWeb                            bool
 	EnableSignup                         bool // Enable creation of accounts via API and UI
 	EnableSignup                         bool // Enable creation of accounts via API and UI

+ 18 - 54
server/server.go

@@ -40,12 +40,10 @@ import (
 		- send dunning emails when overdue
 		- send dunning emails when overdue
 		- payment methods
 		- payment methods
 		- unmarshal to stripe.Subscription instead of gjson
 		- unmarshal to stripe.Subscription instead of gjson
-		- Make ResetTier reset the stripe fields
 		- delete subscription when account deleted
 		- delete subscription when account deleted
-		- remove tier.paid
 		- add tier.visible
 		- add tier.visible
 		- fix tier selection boxes
 		- fix tier selection boxes
-		- account sync after switching tiers
+		- delete messages + reserved topics on ResetTier
 
 
 		Limits & rate limiting:
 		Limits & rate limiting:
 			users without tier: should the stats be persisted? are they meaningful?
 			users without tier: should the stats be persisted? are they meaningful?
@@ -360,7 +358,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 	} else if r.Method == http.MethodGet && r.URL.Path == accountPath {
 	} else if r.Method == http.MethodGet && r.URL.Path == accountPath {
 		return s.handleAccountGet(w, r, v) // Allowed by anonymous
 		return s.handleAccountGet(w, r, v) // Allowed by anonymous
 	} else if r.Method == http.MethodDelete && r.URL.Path == accountPath {
 	} else if r.Method == http.MethodDelete && r.URL.Path == accountPath {
-		return s.ensureUser(s.handleAccountDelete)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountDelete))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
 	} else if r.Method == http.MethodPost && r.URL.Path == accountPasswordPath {
 		return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
 		return s.ensureUser(s.handleAccountPasswordChange)(w, r, v)
 	} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
 	} else if r.Method == http.MethodPatch && r.URL.Path == accountTokenPath {
@@ -368,27 +366,29 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 	} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
 	} else if r.Method == http.MethodDelete && r.URL.Path == accountTokenPath {
 		return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
 		return s.ensureUser(s.handleAccountTokenDelete)(w, r, v)
 	} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
 	} else if r.Method == http.MethodPatch && r.URL.Path == accountSettingsPath {
-		return s.ensureUser(s.handleAccountSettingsChange)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountSettingsChange))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
 	} else if r.Method == http.MethodPost && r.URL.Path == accountSubscriptionPath {
-		return s.ensureUser(s.handleAccountSubscriptionAdd)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionAdd))(w, r, v)
 	} else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodPatch && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
-		return s.ensureUser(s.handleAccountSubscriptionChange)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionChange))(w, r, v)
 	} else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodDelete && accountSubscriptionSingleRegex.MatchString(r.URL.Path) {
-		return s.ensureUser(s.handleAccountSubscriptionDelete)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountSubscriptionDelete))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath {
 	} else if r.Method == http.MethodPost && r.URL.Path == accountReservationPath {
-		return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountReservationAdd))(w, r, v)
 	} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
-		return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
+		return s.ensureUser(s.withAccountSync(s.handleAccountReservationDelete))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath {
 	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingSubscriptionPath {
-		return s.ensureUser(s.handleAccountBillingSubscriptionChange)(w, r, v)
-	} else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath {
-		return s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete)(w, r, v)
+		return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionCreate))(w, r, v) // Account sync via incoming Stripe webhook
 	} else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodGet && accountBillingSubscriptionCheckoutSuccessRegex.MatchString(r.URL.Path) {
-		return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
+		return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingSubscriptionCreateSuccess))(w, r, v) // No user context!
+	} else if r.Method == http.MethodPut && r.URL.Path == accountBillingSubscriptionPath {
+		return s.ensurePaymentsEnabled(s.ensureUser(s.handleAccountBillingSubscriptionUpdate))(w, r, v) // Account sync via incoming Stripe webhook
+	} else if r.Method == http.MethodDelete && r.URL.Path == accountBillingSubscriptionPath {
+		return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingSubscriptionDelete))(w, r, v) // Account sync via incoming Stripe webhook
 	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
 	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
-		return s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate)(w, r, v)
+		return s.ensurePaymentsEnabled(s.ensureStripeCustomer(s.handleAccountBillingPortalSessionCreate))(w, r, v)
 	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
 	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
-		return s.ensureUserManager(s.handleAccountBillingWebhook)(w, r, v)
+		return s.ensurePaymentsEnabled(s.ensureUserManager(s.handleAccountBillingWebhook))(w, r, v)
 	} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
 	} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
 		return s.handleMatrixDiscovery(w)
 		return s.handleMatrixDiscovery(w)
 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
@@ -1423,12 +1423,12 @@ func (s *Server) sendDelayedMessages() error {
 	for _, m := range messages {
 	for _, m := range messages {
 		var v *visitor
 		var v *visitor
 		if s.userManager != nil && m.User != "" {
 		if s.userManager != nil && m.User != "" {
-			user, err := s.userManager.User(m.User)
+			u, err := s.userManager.User(m.User)
 			if err != nil {
 			if err != nil {
 				log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
 				log.Warn("%s Error sending delayed message: %s", logMessagePrefix(v, m), err.Error())
 				continue
 				continue
 			}
 			}
-			v = s.visitorFromUser(user, m.Sender)
+			v = s.visitorFromUser(u, m.Sender)
 		} else {
 		} else {
 			v = s.visitorFromIP(m.Sender)
 			v = s.visitorFromIP(m.Sender)
 		}
 		}
@@ -1475,42 +1475,6 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
 	}
 	}
 }
 }
 
 
-func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
-	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if !s.config.EnableWeb {
-			return errHTTPNotFound
-		}
-		return next(w, r, v)
-	}
-}
-
-func (s *Server) ensureUserManager(next handleFunc) handleFunc {
-	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if s.userManager == nil {
-			return errHTTPNotFound
-		}
-		return next(w, r, v)
-	}
-}
-
-func (s *Server) ensureUser(next handleFunc) handleFunc {
-	return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if v.user == nil {
-			return errHTTPUnauthorized
-		}
-		return next(w, r, v)
-	})
-}
-
-func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
-	return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
-		if v.user.Billing.StripeCustomerID == "" {
-			return errHTTPBadRequestNotAPaidUser
-		}
-		return next(w, r, v)
-	})
-}
-
 // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
 // transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
 // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
 // before passing it on to the next handler. This is meant to be used in combination with handlePublish.
 func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
 func (s *Server) transformBodyJSON(next handleFunc) handleFunc {

+ 10 - 2
server/server.yml

@@ -164,12 +164,10 @@
 # - enable-signup allows users to sign up via the web app, or API
 # - enable-signup allows users to sign up via the web app, or API
 # - enable-login allows users to log in via the web app, or API
 # - enable-login allows users to log in via the web app, or API
 # - enable-reservations allows users to reserve topics (if their tier allows it)
 # - enable-reservations allows users to reserve topics (if their tier allows it)
-# - enable-payments enables payments integration [preliminary option, may change]
 #
 #
 # enable-signup: false
 # enable-signup: false
 # enable-login: false
 # enable-login: false
 # enable-reservations: false
 # enable-reservations: false
-# enable-payments: false
 
 
 # Server URL of a Firebase/APNS-connected ntfy server (likely "https://ntfy.sh").
 # Server URL of a Firebase/APNS-connected ntfy server (likely "https://ntfy.sh").
 #
 #
@@ -216,6 +214,16 @@
 # visitor-attachment-total-size-limit: "100M"
 # visitor-attachment-total-size-limit: "100M"
 # visitor-attachment-daily-bandwidth-limit: "500M"
 # visitor-attachment-daily-bandwidth-limit: "500M"
 
 
+# Payments integration via Stripe
+#
+# - stripe-secret-key is the key used for the Stripe API communication. Setting this values
+#   enables payments in the ntfy web app (e.g. Upgrade dialog). See https://dashboard.stripe.com/apikeys.
+# - stripe-webhook-key is the key required to validate the authenticity of incoming webhooks from Stripe.
+#   Webhooks are essential up keep the local database in sync with the payment provider. See https://dashboard.stripe.com/webhooks.
+#
+# stripe-secret-key:
+# stripe-webhook-key:
+
 # Log level, can be TRACE, DEBUG, INFO, WARN or ERROR
 # Log level, can be TRACE, DEBUG, INFO, WARN or ERROR
 # This option can be hot-reloaded by calling "kill -HUP $pid" or "systemctl reload ntfy".
 # This option can be hot-reloaded by calling "kill -HUP $pid" or "systemctl reload ntfy".
 #
 #

+ 40 - 3
server/server_account.go

@@ -2,15 +2,18 @@ package server
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"errors"
+	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"net/http"
 	"net/http"
 )
 )
 
 
 const (
 const (
-	jsonBodyBytesLimit   = 4096
-	subscriptionIDLength = 16
-	createdByAPI         = "api"
+	jsonBodyBytesLimit        = 4096
+	subscriptionIDLength      = 16
+	createdByAPI              = "api"
+	syncTopicAccountSyncEvent = "sync"
 )
 )
 
 
 func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
 func (s *Server) handleAccountCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
@@ -395,3 +398,37 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
 	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
 	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
 	return nil
 	return nil
 }
 }
+
+func (s *Server) publishSyncEvent(v *visitor) error {
+	if v.user == nil || v.user.SyncTopic == "" {
+		return nil
+	}
+	log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
+	topics, err := s.topicsFromIDs(v.user.SyncTopic)
+	if err != nil {
+		return err
+	} else if len(topics) == 0 {
+		return errors.New("cannot retrieve sync topic")
+	}
+	syncTopic := topics[0]
+	messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent})
+	if err != nil {
+		return err
+	}
+	m := newDefaultMessage(syncTopic.ID, string(messageBytes))
+	if err := syncTopic.Publish(v, m); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *Server) publishSyncEventAsync(v *visitor) {
+	go func() {
+		if v.user == nil || v.user.SyncTopic == "" {
+			return
+		}
+		if err := s.publishSyncEvent(v); err != nil {
+			log.Trace("Error publishing to user %s's sync topic %s: %s", v.user.Name, v.user.SyncTopic, err.Error())
+		}
+	}()
+}

+ 63 - 0
server/server_middleware.go

@@ -0,0 +1,63 @@
+package server
+
+import (
+	"net/http"
+)
+
+func (s *Server) ensureWebEnabled(next handleFunc) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if !s.config.EnableWeb {
+			return errHTTPNotFound
+		}
+		return next(w, r, v)
+	}
+}
+
+func (s *Server) ensureUserManager(next handleFunc) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if s.userManager == nil {
+			return errHTTPNotFound
+		}
+		return next(w, r, v)
+	}
+}
+
+func (s *Server) ensureUser(next handleFunc) handleFunc {
+	return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if v.user == nil {
+			return errHTTPUnauthorized
+		}
+		return next(w, r, v)
+	})
+}
+
+func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if !s.config.EnablePayments {
+			return errHTTPNotFound
+		}
+		return next(w, r, v)
+	}
+}
+
+func (s *Server) ensureStripeCustomer(next handleFunc) handleFunc {
+	return s.ensureUser(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if v.user.Billing.StripeCustomerID == "" {
+			return errHTTPBadRequestNotAPaidUser
+		}
+		return next(w, r, v)
+	})
+}
+
+func (s *Server) withAccountSync(next handleFunc) handleFunc {
+	return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
+		if v.user == nil {
+			return next(w, r, v)
+		}
+		err := next(w, r, v)
+		if err == nil {
+			s.publishSyncEventAsync(v)
+		}
+		return err
+	}
+}

+ 97 - 91
server/server_payments.go

@@ -6,13 +6,14 @@ import (
 	"github.com/stripe/stripe-go/v74"
 	"github.com/stripe/stripe-go/v74"
 	portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
 	portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
 	"github.com/stripe/stripe-go/v74/checkout/session"
 	"github.com/stripe/stripe-go/v74/checkout/session"
+	"github.com/stripe/stripe-go/v74/customer"
 	"github.com/stripe/stripe-go/v74/subscription"
 	"github.com/stripe/stripe-go/v74/subscription"
 	"github.com/stripe/stripe-go/v74/webhook"
 	"github.com/stripe/stripe-go/v74/webhook"
 	"github.com/tidwall/gjson"
 	"github.com/tidwall/gjson"
 	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/log"
-	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
 	"heckel.io/ntfy/util"
 	"net/http"
 	"net/http"
+	"net/netip"
 	"time"
 	"time"
 )
 )
 
 
@@ -20,15 +21,13 @@ const (
 	stripeBodyBytesLimit = 16384
 	stripeBodyBytesLimit = 16384
 )
 )
 
 
-// handleAccountBillingSubscriptionChange facilitates all subscription/tier changes, including payment flows.
-//
-// FIXME this should be two functions!
-//
-// It handles two cases:
-// - Create subscription: Transition from a user without Stripe subscription to a paid subscription (Checkout flow)
-// - Change subscription: Switching between Stripe prices (& tiers) by changing the Stripe subscription
-func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
+// handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
+// will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
+func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	if v.user.Billing.StripeSubscriptionID != "" {
+		return errors.New("subscription already exists") //FIXME
+	}
+	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -36,46 +35,21 @@ func (s *Server) handleAccountBillingSubscriptionChange(w http.ResponseWriter, r
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
-	if v.user.Billing.StripeSubscriptionID == "" && tier.StripePriceID != "" {
-		return s.handleAccountBillingSubscriptionAdd(w, v, tier)
-	} else if v.user.Billing.StripeSubscriptionID != "" {
-		return s.handleAccountBillingSubscriptionUpdate(w, v, tier)
-	}
-	return errors.New("invalid state")
-}
-
-// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
-// and cancelling the Stripe subscription entirely
-func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	if v.user.Billing.StripeCustomerID == "" {
-		return errHTTPBadRequestNotAPaidUser
-	}
-	if v.user.Billing.StripeSubscriptionID != "" {
-		_, err := subscription.Cancel(v.user.Billing.StripeSubscriptionID, nil)
-		if err != nil {
-			return err
-		}
-	}
-	if err := s.userManager.ResetTier(v.user.Name); err != nil {
-		return err
-	}
-	v.user.Billing.StripeSubscriptionID = ""
-	v.user.Billing.StripeSubscriptionStatus = ""
-	v.user.Billing.StripeSubscriptionPaidUntil = time.Unix(0, 0)
-	v.user.Billing.StripeSubscriptionCancelAt = time.Unix(0, 0)
-	if err := s.userManager.ChangeBilling(v.user); err != nil {
-		return err
+	if tier.StripePriceID == "" {
+		return errors.New("invalid tier") //FIXME
 	}
 	}
-	return nil
-}
-
-func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *visitor, tier *user.Tier) error {
 	log.Info("Stripe: No existing subscription, creating checkout flow")
 	log.Info("Stripe: No existing subscription, creating checkout flow")
 	var stripeCustomerID *string
 	var stripeCustomerID *string
 	if v.user.Billing.StripeCustomerID != "" {
 	if v.user.Billing.StripeCustomerID != "" {
 		stripeCustomerID = &v.user.Billing.StripeCustomerID
 		stripeCustomerID = &v.user.Billing.StripeCustomerID
+		stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil)
+		if err != nil {
+			return err
+		} else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
+			return errors.New("customer cannot have more than one subscription") //FIXME
+		}
 	}
 	}
-	successURL := s.config.BaseURL + accountBillingSubscriptionCheckoutSuccessTemplate
+	successURL := s.config.BaseURL + "/account" //+ accountBillingSubscriptionCheckoutSuccessTemplate
 	params := &stripe.CheckoutSessionParams{
 	params := &stripe.CheckoutSessionParams{
 		Customer:          stripeCustomerID, // A user may have previously deleted their subscription
 		Customer:          stripeCustomerID, // A user may have previously deleted their subscription
 		ClientReferenceID: &v.user.Name,     // FIXME Should be user ID
 		ClientReferenceID: &v.user.Name,     // FIXME Should be user ID
@@ -106,36 +80,7 @@ func (s *Server) handleAccountBillingSubscriptionAdd(w http.ResponseWriter, v *v
 	return nil
 	return nil
 }
 }
 
 
-func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, v *visitor, tier *user.Tier) error {
-	log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
-	sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
-	if err != nil {
-		return err
-	}
-	params := &stripe.SubscriptionParams{
-		CancelAtPeriodEnd: stripe.Bool(false),
-		ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
-		Items: []*stripe.SubscriptionItemsParams{
-			{
-				ID:    stripe.String(sub.Items.Data[0].ID),
-				Price: stripe.String(tier.StripePriceID),
-			},
-		},
-	}
-	_, err = subscription.Update(sub.ID, params)
-	if err != nil {
-		return err
-	}
-	response := &apiAccountCheckoutResponse{}
-	w.Header().Set("Content-Type", "application/json")
-	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
-	if err := json.NewEncoder(w).Encode(response); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
+func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
 	// We don't have a v.user in this endpoint, only a userManager!
 	// We don't have a v.user in this endpoint, only a userManager!
 	matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
 	matches := accountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
 	if len(matches) != 2 {
 	if len(matches) != 2 {
@@ -183,6 +128,66 @@ func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r
 	return nil
 	return nil
 }
 }
 
 
+// handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
+// a user's tier accordingly. This endpoint only works if there is an existing subscription.
+func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	if v.user.Billing.StripeSubscriptionID != "" {
+		return errors.New("no existing subscription for user")
+	}
+	req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
+	if err != nil {
+		return err
+	}
+	tier, err := s.userManager.Tier(req.Tier)
+	if err != nil {
+		return err
+	}
+	log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
+	sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
+	if err != nil {
+		return err
+	}
+	params := &stripe.SubscriptionParams{
+		CancelAtPeriodEnd: stripe.Bool(false),
+		ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
+		Items: []*stripe.SubscriptionItemsParams{
+			{
+				ID:    stripe.String(sub.Items.Data[0].ID),
+				Price: stripe.String(tier.StripePriceID),
+			},
+		},
+	}
+	_, err = subscription.Update(sub.ID, params)
+	if err != nil {
+		return err
+	}
+	response := &apiAccountCheckoutResponse{} // FIXME
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
+	if err := json.NewEncoder(w).Encode(response); err != nil {
+		return err
+	}
+	return nil
+}
+
+// handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
+// and cancelling the Stripe subscription entirely
+func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	if v.user.Billing.StripeCustomerID == "" {
+		return errHTTPBadRequestNotAPaidUser
+	}
+	if v.user.Billing.StripeSubscriptionID != "" {
+		params := &stripe.SubscriptionParams{
+			CancelAtPeriodEnd: stripe.Bool(true),
+		}
+		_, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
 func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
 	if v.user.Billing.StripeCustomerID == "" {
 	if v.user.Billing.StripeCustomerID == "" {
 		return errHTTPBadRequestNotAPaidUser
 		return errHTTPBadRequestNotAPaidUser
@@ -206,8 +211,8 @@ func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter,
 	return nil
 	return nil
 }
 }
 
 
-func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, v *visitor) error {
-	// We don't have a v.user in this endpoint, only a userManager!
+func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
+	// Note that the visitor (v) in this endpoint is the Stripe API, so we don't have v.user available
 	stripeSignature := r.Header.Get("Stripe-Signature")
 	stripeSignature := r.Header.Get("Stripe-Signature")
 	if stripeSignature == "" {
 	if stripeSignature == "" {
 		return errHTTPBadRequestInvalidStripeRequest
 		return errHTTPBadRequestInvalidStripeRequest
@@ -225,30 +230,27 @@ func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Requ
 		return errHTTPBadRequestInvalidStripeRequest
 		return errHTTPBadRequestInvalidStripeRequest
 	}
 	}
 	log.Info("Stripe: webhook event %s received", event.Type)
 	log.Info("Stripe: webhook event %s received", event.Type)
-	stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
-	if !stripeCustomerID.Exists() {
-		return errHTTPBadRequestInvalidStripeRequest
-	}
 	switch event.Type {
 	switch event.Type {
 	case "customer.subscription.updated":
 	case "customer.subscription.updated":
-		return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
+		return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
 	case "customer.subscription.deleted":
 	case "customer.subscription.deleted":
-		return s.handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID.String(), event.Data.Raw)
+		return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
 	default:
 	default:
 		return nil
 		return nil
 	}
 	}
 }
 }
 
 
-func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
+func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
+	subscriptionID := gjson.GetBytes(event, "id")
+	customerID := gjson.GetBytes(event, "customer")
 	status := gjson.GetBytes(event, "status")
 	status := gjson.GetBytes(event, "status")
 	currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
 	currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
 	cancelAt := gjson.GetBytes(event, "cancel_at")
 	cancelAt := gjson.GetBytes(event, "cancel_at")
 	priceID := gjson.GetBytes(event, "items.data.0.price.id")
 	priceID := gjson.GetBytes(event, "items.data.0.price.id")
-	if !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
+	if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
 		return errHTTPBadRequestInvalidStripeRequest
 		return errHTTPBadRequestInvalidStripeRequest
 	}
 	}
-	log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
-	u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
+	u, err := s.userManager.UserByStripeCustomer(customerID.String())
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -259,22 +261,25 @@ func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID
 	if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
 	if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
 		return err
 		return err
 	}
 	}
+	u.Billing.StripeSubscriptionID = subscriptionID.String()
 	u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String())
 	u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status.String())
 	u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0)
 	u.Billing.StripeSubscriptionPaidUntil = time.Unix(currentPeriodEnd.Int(), 0)
 	u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0)
 	u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt.Int(), 0)
 	if err := s.userManager.ChangeBilling(u); err != nil {
 	if err := s.userManager.ChangeBilling(u); err != nil {
 		return err
 		return err
 	}
 	}
+	log.Info("Stripe: customer %s: subscription updated to %s, with price %s", customerID.String(), status, priceID)
+	s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
 	return nil
 	return nil
 }
 }
 
 
-func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID string, event json.RawMessage) error {
-	status := gjson.GetBytes(event, "status")
-	if !status.Exists() {
+func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
+	stripeCustomerID := gjson.GetBytes(event, "customer")
+	if !stripeCustomerID.Exists() {
 		return errHTTPBadRequestInvalidStripeRequest
 		return errHTTPBadRequestInvalidStripeRequest
 	}
 	}
-	log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID)
-	u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
+	log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", stripeCustomerID.String())
+	u, err := s.userManager.UserByStripeCustomer(stripeCustomerID.String())
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -288,5 +293,6 @@ func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(stripeCustomerID
 	if err := s.userManager.ChangeBilling(u); err != nil {
 	if err := s.userManager.ChangeBilling(u); err != nil {
 		return err
 		return err
 	}
 	}
+	s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
 	return nil
 	return nil
 }
 }

+ 5 - 1
server/types.go

@@ -305,7 +305,7 @@ type apiConfigResponse struct {
 	DisallowedTopics   []string `json:"disallowed_topics"`
 	DisallowedTopics   []string `json:"disallowed_topics"`
 }
 }
 
 
-type apiAccountTierChangeRequest struct {
+type apiAccountBillingSubscriptionChangeRequest struct {
 	Tier string `json:"tier"`
 	Tier string `json:"tier"`
 }
 }
 
 
@@ -316,3 +316,7 @@ type apiAccountCheckoutResponse struct {
 type apiAccountBillingPortalRedirectResponse struct {
 type apiAccountBillingPortalRedirectResponse struct {
 	RedirectURL string `json:"redirect_url"`
 	RedirectURL string `json:"redirect_url"`
 }
 }
+
+type apiAccountSyncTopicResponse struct {
+	Event string `json:"event"`
+}

+ 12 - 15
user/manager.go

@@ -38,7 +38,6 @@ const (
 			id INTEGER PRIMARY KEY AUTOINCREMENT,		
 			id INTEGER PRIMARY KEY AUTOINCREMENT,		
 			code TEXT NOT NULL,
 			code TEXT NOT NULL,
 			name TEXT NOT NULL,
 			name TEXT NOT NULL,
-			paid INT NOT NULL,
 			messages_limit INT NOT NULL,
 			messages_limit INT NOT NULL,
 			messages_expiry_duration INT NOT NULL,
 			messages_expiry_duration INT NOT NULL,
 			emails_limit INT NOT NULL,
 			emails_limit INT NOT NULL,
@@ -104,20 +103,20 @@ const (
 	`
 	`
 
 
 	selectUserByNameQuery = `
 	selectUserByNameQuery = `
-		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
+		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
 		FROM user u
 		FROM user u
 		LEFT JOIN tier p on p.id = u.tier_id
 		LEFT JOIN tier p on p.id = u.tier_id
 		WHERE user = ?		
 		WHERE user = ?		
 	`
 	`
 	selectUserByTokenQuery = `
 	selectUserByTokenQuery = `
-		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at , p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
+		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
 		FROM user u
 		FROM user u
 		JOIN user_token t on u.id = t.user_id
 		JOIN user_token t on u.id = t.user_id
 		LEFT JOIN tier p on p.id = u.tier_id
 		LEFT JOIN tier p on p.id = u.tier_id
 		WHERE t.token = ? AND t.expires >= ?
 		WHERE t.token = ? AND t.expires >= ?
 	`
 	`
 	selectUserByStripeCustomerIDQuery = `
 	selectUserByStripeCustomerIDQuery = `
-		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at , p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
+		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, p.code, p.name, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
 		FROM user u
 		FROM user u
 		LEFT JOIN tier p on p.id = u.tier_id
 		LEFT JOIN tier p on p.id = u.tier_id
 		WHERE u.stripe_customer_id = ?
 		WHERE u.stripe_customer_id = ?
@@ -218,17 +217,17 @@ const (
 	`
 	`
 
 
 	insertTierQuery = `
 	insertTierQuery = `
-		INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
-		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+		INSERT INTO tier (code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
+		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
 	`
 	`
 	selectTierIDQuery     = `SELECT id FROM tier WHERE code = ?`
 	selectTierIDQuery     = `SELECT id FROM tier WHERE code = ?`
 	selectTierByCodeQuery = `
 	selectTierByCodeQuery = `
-		SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
+		SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
 		FROM tier
 		FROM tier
 		WHERE code = ?
 		WHERE code = ?
 	`
 	`
 	selectTierByPriceIDQuery = `
 	selectTierByPriceIDQuery = `
-		SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
+		SELECT code, name, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
 		FROM tier
 		FROM tier
 		WHERE stripe_price_id = ?
 		WHERE stripe_price_id = ?
 	`
 	`
@@ -606,13 +605,12 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	defer rows.Close()
 	defer rows.Close()
 	var username, hash, role, prefs, syncTopic string
 	var username, hash, role, prefs, syncTopic string
 	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
 	var stripeCustomerID, stripeSubscriptionID, stripeSubscriptionStatus, stripePriceID, tierCode, tierName sql.NullString
-	var paid sql.NullBool
 	var messages, emails int64
 	var messages, emails int64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration, stripeSubscriptionPaidUntil, stripeSubscriptionCancelAt sql.NullInt64
 	if !rows.Next() {
 	if !rows.Next() {
 		return nil, ErrUserNotFound
 		return nil, ErrUserNotFound
 	}
 	}
-	if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
+	if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &stripeSubscriptionStatus, &stripeSubscriptionPaidUntil, &stripeSubscriptionCancelAt, &tierCode, &tierName, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
 		return nil, err
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 	} else if err := rows.Err(); err != nil {
 		return nil, err
 		return nil, err
@@ -643,7 +641,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 		user.Tier = &Tier{
 		user.Tier = &Tier{
 			Code:                     tierCode.String,
 			Code:                     tierCode.String,
 			Name:                     tierName.String,
 			Name:                     tierName.String,
-			Paid:                     paid.Bool,
+			Paid:                     stripePriceID.Valid, // If there is a price, it's a paid tier
 			MessagesLimit:            messagesLimit.Int64,
 			MessagesLimit:            messagesLimit.Int64,
 			MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
 			MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
 			EmailsLimit:              emailsLimit.Int64,
 			EmailsLimit:              emailsLimit.Int64,
@@ -870,7 +868,7 @@ func (a *Manager) DefaultAccess() Permission {
 
 
 // CreateTier creates a new tier in the database
 // CreateTier creates a new tier in the database
 func (a *Manager) CreateTier(tier *Tier) error {
 func (a *Manager) CreateTier(tier *Tier) error {
-	if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.Paid, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil {
+	if _, err := a.db.Exec(insertTierQuery, tier.Code, tier.Name, tier.MessagesLimit, int64(tier.MessagesExpiryDuration.Seconds()), tier.EmailsLimit, tier.ReservationsLimit, tier.AttachmentFileSizeLimit, tier.AttachmentTotalSizeLimit, int64(tier.AttachmentExpiryDuration.Seconds())); err != nil {
 		return err
 		return err
 	}
 	}
 	return nil
 	return nil
@@ -903,12 +901,11 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
 	defer rows.Close()
 	defer rows.Close()
 	var code, name string
 	var code, name string
 	var stripePriceID sql.NullString
 	var stripePriceID sql.NullString
-	var paid bool
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
 	if !rows.Next() {
 	if !rows.Next() {
 		return nil, ErrTierNotFound
 		return nil, ErrTierNotFound
 	}
 	}
-	if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
+	if err := rows.Scan(&code, &name, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
 		return nil, err
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 	} else if err := rows.Err(); err != nil {
 		return nil, err
 		return nil, err
@@ -917,7 +914,7 @@ func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
 	return &Tier{
 	return &Tier{
 		Code:                     code,
 		Code:                     code,
 		Name:                     name,
 		Name:                     name,
-		Paid:                     paid,
+		Paid:                     stripePriceID.Valid, // If there is a price, it's a paid tier
 		MessagesLimit:            messagesLimit.Int64,
 		MessagesLimit:            messagesLimit.Int64,
 		MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
 		MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
 		EmailsLimit:              emailsLimit.Int64,
 		EmailsLimit:              emailsLimit.Int64,

+ 5 - 1
web/public/static/langs/en.json

@@ -179,8 +179,10 @@
   "account_usage_unlimited": "Unlimited",
   "account_usage_unlimited": "Unlimited",
   "account_usage_limits_reset_daily": "Usage limits are reset daily at midnight (UTC)",
   "account_usage_limits_reset_daily": "Usage limits are reset daily at midnight (UTC)",
   "account_usage_tier_title": "Account type",
   "account_usage_tier_title": "Account type",
+  "account_usage_tier_description": "Your account's power level",
   "account_usage_tier_admin": "Admin",
   "account_usage_tier_admin": "Admin",
-  "account_usage_tier_none": "Basic",
+  "account_usage_tier_basic": "Basic",
+  "account_usage_tier_free": "Free",
   "account_usage_tier_upgrade_button": "Upgrade to Pro",
   "account_usage_tier_upgrade_button": "Upgrade to Pro",
   "account_usage_tier_change_button": "Change",
   "account_usage_tier_change_button": "Change",
   "account_usage_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
   "account_usage_tier_paid_until": "Subscription paid until {{date}}, and will auto-renew",
@@ -199,6 +201,8 @@
   "account_delete_dialog_label": "Type '{{username}}' to delete account",
   "account_delete_dialog_label": "Type '{{username}}' to delete account",
   "account_delete_dialog_button_cancel": "Cancel",
   "account_delete_dialog_button_cancel": "Cancel",
   "account_delete_dialog_button_submit": "Permanently delete account",
   "account_delete_dialog_button_submit": "Permanently delete account",
+  "account_upgrade_dialog_title": "Change billing plan",
+  "account_upgrade_dialog_cancel_warning": "This will cancel your subscription, and downgrade your account on {{date}}. On that date, topic reservations as well as messages cached on the server will be deleted.",
   "prefs_notifications_title": "Notifications",
   "prefs_notifications_title": "Notifications",
   "prefs_notifications_sound_title": "Notification sound",
   "prefs_notifications_sound_title": "Notification sound",
   "prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive",
   "prefs_notifications_sound_description_none": "Notifications do not play any sound when they arrive",

+ 13 - 3
web/src/app/AccountApi.js

@@ -264,11 +264,20 @@ class AccountApi {
         this.triggerChange(); // Dangle!
         this.triggerChange(); // Dangle!
     }
     }
 
 
+    async createBillingSubscription(tier) {
+        console.log(`[AccountApi] Creating billing subscription with ${tier}`);
+        return await this.upsertBillingSubscription("POST", tier)
+    }
+
     async updateBillingSubscription(tier) {
     async updateBillingSubscription(tier) {
+        console.log(`[AccountApi] Updating billing subscription with ${tier}`);
+        return await this.upsertBillingSubscription("PUT", tier)
+    }
+
+    async upsertBillingSubscription(method, tier) {
         const url = accountBillingSubscriptionUrl(config.base_url);
         const url = accountBillingSubscriptionUrl(config.base_url);
-        console.log(`[AccountApi] Requesting tier change to ${tier}`);
         const response = await fetch(url, {
         const response = await fetch(url, {
-            method: "POST",
+            method: method,
             headers: withBearerAuth({}, session.token()),
             headers: withBearerAuth({}, session.token()),
             body: JSON.stringify({
             body: JSON.stringify({
                 tier: tier
                 tier: tier
@@ -284,7 +293,7 @@ class AccountApi {
 
 
     async deleteBillingSubscription() {
     async deleteBillingSubscription() {
         const url = accountBillingSubscriptionUrl(config.base_url);
         const url = accountBillingSubscriptionUrl(config.base_url);
-        console.log(`[AccountApi] Cancelling paid subscription`);
+        console.log(`[AccountApi] Cancelling billing subscription`);
         const response = await fetch(url, {
         const response = await fetch(url, {
             method: "DELETE",
             method: "DELETE",
             headers: withBearerAuth({}, session.token())
             headers: withBearerAuth({}, session.token())
@@ -345,6 +354,7 @@ class AccountApi {
     }
     }
 
 
     async triggerChange() {
     async triggerChange() {
+        return null;
         const account = await this.get();
         const account = await this.get();
         if (!account.sync_topic) {
         if (!account.sync_topic) {
             return;
             return;

+ 84 - 59
web/src/components/Account.js

@@ -56,6 +56,7 @@ const Basics = () => {
             <PrefGroup>
             <PrefGroup>
                 <Username/>
                 <Username/>
                 <ChangePassword/>
                 <ChangePassword/>
+                <AccountType/>
             </PrefGroup>
             </PrefGroup>
         </Card>
         </Card>
     );
     );
@@ -168,18 +169,20 @@ const ChangePasswordDialog = (props) => {
     );
     );
 };
 };
 
 
-const Stats = () => {
+const AccountType = () => {
     const { t } = useTranslation();
     const { t } = useTranslation();
     const { account } = useContext(AccountContext);
     const { account } = useContext(AccountContext);
+    const [upgradeDialogKey, setUpgradeDialogKey] = useState(0);
     const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
     const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
 
 
     if (!account) {
     if (!account) {
         return <></>;
         return <></>;
     }
     }
 
 
-    const normalize = (value, max) => {
-        return Math.min(value / max * 100, 100);
-    };
+    const handleUpgradeClick = () => {
+        setUpgradeDialogKey(k => k + 1);
+        setUpgradeDialogOpen(true);
+    }
 
 
     const handleManageBilling = async () => {
     const handleManageBilling = async () => {
         try {
         try {
@@ -194,67 +197,89 @@ const Stats = () => {
         }
         }
     };
     };
 
 
+    let accountType;
+    if (account.role === "admin") {
+        const tierSuffix = (account.tier) ? `(with ${account.tier.name} tier)` : `(no tier)`;
+        accountType = `${t("account_usage_tier_admin")} ${tierSuffix}`;
+    } else if (!account.tier) {
+        accountType = (config.enable_payments) ? t("account_usage_tier_free") : t("account_usage_tier_basic");
+    } else {
+        accountType = account.tier.name;
+    }
+
+    return (
+        <Pref
+            alignTop={account.billing?.status === "past_due" || account.billing?.cancel_at > 0}
+            title={t("account_usage_tier_title")}
+            description={t("account_usage_tier_description")}
+        >
+            <div>
+                {accountType}
+                {account.billing?.paid_until && !account.billing?.cancel_at &&
+                    <Tooltip title={t("account_usage_tier_paid_until", { date: formatShortDate(account.billing?.paid_until) })}>
+                        <span><InfoIcon/></span>
+                    </Tooltip>
+                }
+                {config.enable_payments && account.role === "user" && !account.billing?.subscription &&
+                    <Button
+                        variant="outlined"
+                        size="small"
+                        startIcon={<CelebrationIcon sx={{ color: "#55b86e" }}/>}
+                        onClick={handleUpgradeClick}
+                        sx={{ml: 1}}
+                    >{t("account_usage_tier_upgrade_button")}</Button>
+                }
+                {config.enable_payments && account.role === "user" && account.billing?.subscription &&
+                    <Button
+                        variant="outlined"
+                        size="small"
+                        onClick={handleUpgradeClick}
+                        sx={{ml: 1}}
+                    >{t("account_usage_tier_change_button")}</Button>
+                }
+                {config.enable_payments && account.role === "user" && account.billing?.customer &&
+                    <Button
+                        variant="outlined"
+                        size="small"
+                        onClick={handleManageBilling}
+                        sx={{ml: 1}}
+                    >{t("account_usage_manage_billing_button")}</Button>
+                }
+                <UpgradeDialog
+                    key={`upgradeDialogFromAccount${upgradeDialogKey}`}
+                    open={upgradeDialogOpen}
+                    onCancel={() => setUpgradeDialogOpen(false)}
+                />
+            </div>
+            {account.billing?.status === "past_due" &&
+                <Alert severity="error" sx={{mt: 1}}>{t("account_usage_tier_payment_overdue")}</Alert>
+            }
+            {account.billing?.cancel_at > 0 &&
+                <Alert severity="warning" sx={{mt: 1}}>{t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })}</Alert>
+            }
+        </Pref>
+    )
+};
+
+const Stats = () => {
+    const { t } = useTranslation();
+    const { account } = useContext(AccountContext);
+    const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
+
+    if (!account) {
+        return <></>;
+    }
+
+    const normalize = (value, max) => {
+        return Math.min(value / max * 100, 100);
+    };
+
     return (
     return (
         <Card sx={{p: 3}} aria-label={t("account_usage_title")}>
         <Card sx={{p: 3}} aria-label={t("account_usage_title")}>
             <Typography variant="h5" sx={{marginBottom: 2}}>
             <Typography variant="h5" sx={{marginBottom: 2}}>
                 {t("account_usage_title")}
                 {t("account_usage_title")}
             </Typography>
             </Typography>
             <PrefGroup>
             <PrefGroup>
-                <Pref
-                    alignTop={account.billing?.status === "past_due" || account.billing?.cancel_at > 0}
-                    title={t("account_usage_tier_title")}
-                >
-                    <div>
-                        {account.role === "admin" &&
-                            <>
-                                {t("account_usage_tier_admin")}
-                                {" "}{account.tier ? `(with ${account.tier.name} tier)` : `(no tier)`}
-                            </>
-                        }
-                        {account.role === "user" && account.tier && account.tier.name}
-                        {account.role === "user" && !account.tier && t("account_usage_tier_none")}
-                        {account.billing?.paid_until &&
-                            <Tooltip title={t("account_usage_tier_paid_until", { date: formatShortDate(account.billing?.paid_until) })}>
-                                <span><InfoIcon/></span>
-                            </Tooltip>
-                        }
-                        {config.enable_payments && account.role === "user" && (!account.tier || !account.tier.paid) &&
-                            <Button
-                                variant="outlined"
-                                size="small"
-                                startIcon={<CelebrationIcon sx={{ color: "#55b86e" }}/>}
-                                onClick={() => setUpgradeDialogOpen(true)}
-                                sx={{ml: 1}}
-                            >{t("account_usage_tier_upgrade_button")}</Button>
-                        }
-                        {config.enable_payments && account.role === "user" && account.tier?.paid &&
-                            <Button
-                                variant="outlined"
-                                size="small"
-                                onClick={() => setUpgradeDialogOpen(true)}
-                                sx={{ml: 1}}
-                            >{t("account_usage_tier_change_button")}</Button>
-                        }
-                        {config.enable_payments && account.role === "user" && account.billing?.customer &&
-                            <Button
-                                variant="outlined"
-                                size="small"
-                                onClick={handleManageBilling}
-                                sx={{ml: 1}}
-                            >{t("account_usage_manage_billing_button")}</Button>
-                        }
-                        <UpgradeDialog
-                            open={upgradeDialogOpen}
-                            onCancel={() => setUpgradeDialogOpen(false)}
-                        />
-                    </div>
-                    {account.billing?.status === "past_due" &&
-                        <Alert severity="error" sx={{mt: 1}}>{t("account_usage_tier_payment_overdue")}</Alert>
-                    }
-                    {account.billing?.cancel_at > 0 &&
-                        <Alert severity="info" sx={{mt: 1}}>{t("account_usage_tier_canceled_subscription", { date: formatShortDate(account.billing.cancel_at) })}</Alert>
-                    }
-                </Pref>
                 {account.role !== "admin" &&
                 {account.role !== "admin" &&
                     <Pref title={t("account_usage_reservations_title")}>
                     <Pref title={t("account_usage_reservations_title")}>
                         {account.limits.reservations > 0 &&
                         {account.limits.reservations > 0 &&

+ 11 - 3
web/src/components/Navigation.js

@@ -103,8 +103,8 @@ const NavList = (props) => {
     };
     };
 
 
     const isAdmin = account?.role === "admin";
     const isAdmin = account?.role === "admin";
-    const isPaid = account?.tier?.paid;
-    const showUpgradeBanner = config.enable_payments && !isAdmin && !isPaid;// && (!props.account || !props.account.tier || !props.account.tier.paid || props.account);
+    const isPaid = account?.billing?.subscription;
+    const showUpgradeBanner = config.enable_payments && !isAdmin && !isPaid;
     const showSubscriptionsList = props.subscriptions?.length > 0;
     const showSubscriptionsList = props.subscriptions?.length > 0;
     const showNotificationBrowserNotSupportedBox = !notifier.browserSupported();
     const showNotificationBrowserNotSupportedBox = !notifier.browserSupported();
     const showNotificationContextNotSupportedBox = notifier.browserSupported() && !notifier.contextSupported(); // Only show if notifications are generally supported in the browser
     const showNotificationContextNotSupportedBox = notifier.browserSupported() && !notifier.contextSupported(); // Only show if notifications are generally supported in the browser
@@ -174,7 +174,14 @@ const NavList = (props) => {
 };
 };
 
 
 const UpgradeBanner = () => {
 const UpgradeBanner = () => {
+    const [dialogKey, setDialogKey] = useState(0);
     const [dialogOpen, setDialogOpen] = useState(false);
     const [dialogOpen, setDialogOpen] = useState(false);
+
+    const handleClick = () => {
+        setDialogKey(k => k + 1);
+        setDialogOpen(true);
+    };
+
     return (
     return (
         <Box sx={{
         <Box sx={{
             position: "fixed",
             position: "fixed",
@@ -184,7 +191,7 @@ const UpgradeBanner = () => {
             background: "linear-gradient(150deg, rgba(196, 228, 221, 0.46) 0%, rgb(255, 255, 255) 100%)",
             background: "linear-gradient(150deg, rgba(196, 228, 221, 0.46) 0%, rgb(255, 255, 255) 100%)",
         }}>
         }}>
             <Divider/>
             <Divider/>
-            <ListItemButton onClick={() => setDialogOpen(true)} sx={{pt: 2, pb: 2}}>
+            <ListItemButton onClick={handleClick} sx={{pt: 2, pb: 2}}>
                 <ListItemIcon><CelebrationIcon sx={{ color: "#55b86e" }} fontSize="large"/></ListItemIcon>
                 <ListItemIcon><CelebrationIcon sx={{ color: "#55b86e" }} fontSize="large"/></ListItemIcon>
                 <ListItemText
                 <ListItemText
                     sx={{ ml: 1 }}
                     sx={{ ml: 1 }}
@@ -207,6 +214,7 @@ const UpgradeBanner = () => {
                 />
                 />
             </ListItemButton>
             </ListItemButton>
             <UpgradeDialog
             <UpgradeDialog
+                key={`upgradeDialog${dialogKey}`}
                 open={dialogOpen}
                 open={dialogOpen}
                 onCancel={() => setDialogOpen(false)}
                 onCancel={() => setDialogOpen(false)}
             />
             />

+ 52 - 16
web/src/components/UpgradeDialog.js

@@ -2,7 +2,7 @@ import * as React from 'react';
 import Dialog from '@mui/material/Dialog';
 import Dialog from '@mui/material/Dialog';
 import DialogContent from '@mui/material/DialogContent';
 import DialogContent from '@mui/material/DialogContent';
 import DialogTitle from '@mui/material/DialogTitle';
 import DialogTitle from '@mui/material/DialogTitle';
-import {CardActionArea, CardContent, useMediaQuery} from "@mui/material";
+import {Alert, CardActionArea, CardContent, useMediaQuery} from "@mui/material";
 import theme from "./theme";
 import theme from "./theme";
 import DialogFooter from "./DialogFooter";
 import DialogFooter from "./DialogFooter";
 import Button from "@mui/material/Button";
 import Button from "@mui/material/Button";
@@ -13,28 +13,53 @@ import {useContext, useState} from "react";
 import Card from "@mui/material/Card";
 import Card from "@mui/material/Card";
 import Typography from "@mui/material/Typography";
 import Typography from "@mui/material/Typography";
 import {AccountContext} from "./App";
 import {AccountContext} from "./App";
+import {formatShortDate} from "../app/utils";
+import {useTranslation} from "react-i18next";
 
 
 const UpgradeDialog = (props) => {
 const UpgradeDialog = (props) => {
+    const { t } = useTranslation();
     const { account } = useContext(AccountContext);
     const { account } = useContext(AccountContext);
     const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
     const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
     const [newTier, setNewTier] = useState(account?.tier?.code || null);
     const [newTier, setNewTier] = useState(account?.tier?.code || null);
     const [errorText, setErrorText] = useState("");
     const [errorText, setErrorText] = useState("");
 
 
-    const handleCheckout = async () => {
+    if (!account) {
+        return <></>;
+    }
+
+    const currentTier = account.tier?.code || null;
+    let action, submitButtonLabel, submitButtonEnabled;
+    if (currentTier === newTier) {
+        submitButtonLabel = "Update subscription";
+        submitButtonEnabled = false;
+        action = null;
+    } else if (currentTier === null) {
+        submitButtonLabel = "Pay $5 now and subscribe";
+        submitButtonEnabled = true;
+        action = Action.CREATE;
+    } else if (newTier === null) {
+        submitButtonLabel = "Cancel subscription";
+        submitButtonEnabled = true;
+        action = Action.CANCEL;
+    } else {
+        submitButtonLabel = "Update subscription";
+        submitButtonEnabled = true;
+        action = Action.UPDATE;
+    }
+
+    const handleSubmit = async () => {
         try {
         try {
-            if (newTier == null) {
+            if (action === Action.CREATE) {
+                const response = await accountApi.createBillingSubscription(newTier);
+                window.location.href = response.redirect_url;
+            } else if (action === Action.UPDATE) {
+                await accountApi.updateBillingSubscription(newTier);
+            } else if (action === Action.CANCEL) {
                 await accountApi.deleteBillingSubscription();
                 await accountApi.deleteBillingSubscription();
-            } else {
-                const response = await accountApi.updateBillingSubscription(newTier);
-                if (response.redirect_url) {
-                    window.location.href = response.redirect_url;
-                } else {
-                    await accountApi.sync();
-                }
             }
             }
-
+            props.onCancel();
         } catch (e) {
         } catch (e) {
-            console.log(`[UpgradeDialog] Error creating checkout session`, e);
+            console.log(`[UpgradeDialog] Error changing billing subscription`, e);
             if ((e instanceof UnauthorizedError)) {
             if ((e instanceof UnauthorizedError)) {
                 session.resetAndRedirect(routes.login);
                 session.resetAndRedirect(routes.login);
             }
             }
@@ -44,7 +69,7 @@ const UpgradeDialog = (props) => {
 
 
     return (
     return (
         <Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
         <Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
-            <DialogTitle>Upgrade to Pro</DialogTitle>
+            <DialogTitle>Change billing plan</DialogTitle>
             <DialogContent>
             <DialogContent>
                 <div style={{
                 <div style={{
                     display: "flex",
                     display: "flex",
@@ -55,9 +80,15 @@ const UpgradeDialog = (props) => {
                     <TierCard code="pro" name={"Pro"} selected={newTier === "pro"} onClick={() => setNewTier("pro")}/>
                     <TierCard code="pro" name={"Pro"} selected={newTier === "pro"} onClick={() => setNewTier("pro")}/>
                     <TierCard code="business" name={"Business"} selected={newTier === "business"} onClick={() => setNewTier("business")}/>
                     <TierCard code="business" name={"Business"} selected={newTier === "business"} onClick={() => setNewTier("business")}/>
                 </div>
                 </div>
+                {action === Action.CANCEL &&
+                    <Alert severity="warning">
+                        {t("account_upgrade_dialog_cancel_warning", { date: formatShortDate(account.billing.paid_until) })}
+                    </Alert>
+                }
             </DialogContent>
             </DialogContent>
             <DialogFooter status={errorText}>
             <DialogFooter status={errorText}>
-                <Button onClick={handleCheckout}>Checkout</Button>
+                <Button onClick={props.onCancel}>Cancel</Button>
+                <Button onClick={handleSubmit} disabled={!submitButtonEnabled}>{submitButtonLabel}</Button>
             </DialogFooter>
             </DialogFooter>
         </Dialog>
         </Dialog>
     );
     );
@@ -65,8 +96,7 @@ const UpgradeDialog = (props) => {
 
 
 const TierCard = (props) => {
 const TierCard = (props) => {
     const cardStyle = (props.selected) ? {
     const cardStyle = (props.selected) ? {
-        border: "1px solid red",
-
+        background: "#eee"
     } : {};
     } : {};
     return (
     return (
         <Card sx={{ m: 1, maxWidth: 345 }}>
         <Card sx={{ m: 1, maxWidth: 345 }}>
@@ -85,4 +115,10 @@ const TierCard = (props) => {
     );
     );
 }
 }
 
 
+const Action = {
+    CREATE: 1,
+    UPDATE: 2,
+    CANCEL: 3
+};
+
 export default UpgradeDialog;
 export default UpgradeDialog;