Przeglądaj źródła

WIP: Stripe integration

binwiederhier 3 lat temu
rodzic
commit
01fd4754f9

+ 2 - 2
cmd/access.go

@@ -103,7 +103,7 @@ func changeAccess(c *cli.Context, manager *user.Manager, username string, topic
 	read := util.Contains([]string{"read-write", "rw", "read-only", "read", "ro"}, perms)
 	write := util.Contains([]string{"read-write", "rw", "write-only", "write", "wo"}, perms)
 	u, err := manager.User(username)
-	if err == user.ErrNotFound {
+	if err == user.ErrUserNotFound {
 		return fmt.Errorf("user %s does not exist", username)
 	} else if u.Role == user.RoleAdmin {
 		return fmt.Errorf("user %s is an admin user, access control entries have no effect", username)
@@ -173,7 +173,7 @@ func showAllAccess(c *cli.Context, manager *user.Manager) error {
 
 func showUserAccess(c *cli.Context, manager *user.Manager, username string) error {
 	users, err := manager.User(username)
-	if err == user.ErrNotFound {
+	if err == user.ErrUserNotFound {
 		return fmt.Errorf("user %s does not exist", username)
 	} else if err != nil {
 		return err

+ 17 - 4
cmd/serve.go

@@ -5,6 +5,7 @@ package cmd
 import (
 	"errors"
 	"fmt"
+	"github.com/stripe/stripe-go/v74"
 	"heckel.io/ntfy/user"
 	"io/fs"
 	"math"
@@ -61,7 +62,6 @@ var flagsServe = append(
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-signup", Aliases: []string{"enable_signup"}, EnvVars: []string{"NTFY_ENABLE_SIGNUP"}, Value: false, Usage: "allows users to sign up via the web app, or API"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-login", Aliases: []string{"enable_login"}, EnvVars: []string{"NTFY_ENABLE_LOGIN"}, Value: false, Usage: "allows users to log in via the web app, or API"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-reservations", Aliases: []string{"enable_reservations"}, EnvVars: []string{"NTFY_ENABLE_RESERVATIONS"}, Value: false, Usage: "allows users to reserve topics (if their tier allows it)"}),
-	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "enable-payments", Aliases: []string{"enable_payments"}, EnvVars: []string{"NTFY_ENABLE_PAYMENTS"}, Value: false, Usage: "enables payments integration [preliminary option, may change]"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "upstream-base-url", Aliases: []string{"upstream_base_url"}, EnvVars: []string{"NTFY_UPSTREAM_BASE_URL"}, Value: "", Usage: "forward poll request to an upstream server, this is needed for iOS push notifications for self-hosted servers"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-addr", Aliases: []string{"smtp_sender_addr"}, EnvVars: []string{"NTFY_SMTP_SENDER_ADDR"}, Usage: "SMTP server address (host:port) for outgoing emails"}),
 	altsrc.NewStringFlag(&cli.StringFlag{Name: "smtp-sender-user", Aliases: []string{"smtp_sender_user"}, EnvVars: []string{"NTFY_SMTP_SENDER_USER"}, Usage: "SMTP user (if e-mail sending is enabled)"}),
@@ -80,6 +80,8 @@ var flagsServe = append(
 	altsrc.NewIntFlag(&cli.IntFlag{Name: "visitor-email-limit-burst", Aliases: []string{"visitor_email_limit_burst"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_BURST"}, Value: server.DefaultVisitorEmailLimitBurst, Usage: "initial limit of e-mails per visitor"}),
 	altsrc.NewDurationFlag(&cli.DurationFlag{Name: "visitor-email-limit-replenish", Aliases: []string{"visitor_email_limit_replenish"}, EnvVars: []string{"NTFY_VISITOR_EMAIL_LIMIT_REPLENISH"}, Value: server.DefaultVisitorEmailLimitReplenish, Usage: "interval at which burst limit is replenished (one per x)"}),
 	altsrc.NewBoolFlag(&cli.BoolFlag{Name: "behind-proxy", Aliases: []string{"behind_proxy", "P"}, EnvVars: []string{"NTFY_BEHIND_PROXY"}, Value: false, Usage: "if set, use X-Forwarded-For header to determine visitor IP address (for rate limiting)"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-key", Aliases: []string{"stripe_key"}, EnvVars: []string{"NTFY_STRIPE_KEY"}, Value: "", Usage: "xxxxxxxxxxxxx"}),
+	altsrc.NewStringFlag(&cli.StringFlag{Name: "stripe-webhook-key", Aliases: []string{"stripe_webhook_key"}, EnvVars: []string{"NTFY_STRIPE_WEBHOOK_KEY"}, Value: "", Usage: "xxxxxxxxxxxx"}),
 )
 
 var cmdServe = &cli.Command{
@@ -132,7 +134,6 @@ func execServe(c *cli.Context) error {
 	webRoot := c.String("web-root")
 	enableSignup := c.Bool("enable-signup")
 	enableLogin := c.Bool("enable-login")
-	enablePayments := c.Bool("enable-payments")
 	enableReservations := c.Bool("enable-reservations")
 	upstreamBaseURL := c.String("upstream-base-url")
 	smtpSenderAddr := c.String("smtp-sender-addr")
@@ -152,6 +153,8 @@ func execServe(c *cli.Context) error {
 	visitorEmailLimitBurst := c.Int("visitor-email-limit-burst")
 	visitorEmailLimitReplenish := c.Duration("visitor-email-limit-replenish")
 	behindProxy := c.Bool("behind-proxy")
+	stripeKey := c.String("stripe-key")
+	stripeWebhookKey := c.String("stripe-webhook-key")
 
 	// Check values
 	if firebaseKeyFile != "" && !util.FileExists(firebaseKeyFile) {
@@ -188,14 +191,17 @@ func execServe(c *cli.Context) error {
 		return errors.New("if upstream-base-url is set, base-url must also be set")
 	} else if upstreamBaseURL != "" && baseURL != "" && baseURL == upstreamBaseURL {
 		return errors.New("base-url and upstream-base-url cannot be identical, you'll likely want to set upstream-base-url to https://ntfy.sh, see https://ntfy.sh/docs/config/#ios-instant-notifications")
-	} else if authFile == "" && (enableSignup || enableLogin || enableReservations || enablePayments) {
-		return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or enable-payments if auth-file is not set")
+	} else if authFile == "" && (enableSignup || enableLogin || enableReservations || stripeKey != "") {
+		return errors.New("cannot set enable-signup, enable-login, enable-reserve-topics, or stripe-key if auth-file is not set")
 	} else if enableSignup && !enableLogin {
 		return errors.New("cannot set enable-signup without also setting enable-login")
+	} else if stripeKey != "" && (stripeWebhookKey == "" || baseURL == "") {
+		return errors.New("if stripe-key is set, stripe-webhook-key and base-url must also be set")
 	}
 
 	webRootIsApp := webRoot == "app"
 	enableWeb := webRoot != "disable"
+	enablePayments := stripeKey != ""
 
 	// Default auth permissions
 	authDefault, err := user.ParsePermission(authDefaultAccess)
@@ -239,6 +245,11 @@ func execServe(c *cli.Context) error {
 		visitorRequestLimitExemptIPs = append(visitorRequestLimitExemptIPs, ips...)
 	}
 
+	// Stripe things
+	if stripeKey != "" {
+		stripe.Key = stripeKey
+	}
+
 	// Run server
 	conf := server.NewConfig()
 	conf.BaseURL = baseURL
@@ -282,6 +293,8 @@ func execServe(c *cli.Context) error {
 	conf.VisitorEmailLimitBurst = visitorEmailLimitBurst
 	conf.VisitorEmailLimitReplenish = visitorEmailLimitReplenish
 	conf.BehindProxy = behindProxy
+	conf.StripeKey = stripeKey
+	conf.StripeWebhookKey = stripeWebhookKey
 	conf.EnableWeb = enableWeb
 	conf.EnableSignup = enableSignup
 	conf.EnableLogin = enableLogin

+ 4 - 4
cmd/user.go

@@ -215,7 +215,7 @@ func execUserDel(c *cli.Context) error {
 	if err != nil {
 		return err
 	}
-	if _, err := manager.User(username); err == user.ErrNotFound {
+	if _, err := manager.User(username); err == user.ErrUserNotFound {
 		return fmt.Errorf("user %s does not exist", username)
 	}
 	if err := manager.RemoveUser(username); err != nil {
@@ -237,7 +237,7 @@ func execUserChangePass(c *cli.Context) error {
 	if err != nil {
 		return err
 	}
-	if _, err := manager.User(username); err == user.ErrNotFound {
+	if _, err := manager.User(username); err == user.ErrUserNotFound {
 		return fmt.Errorf("user %s does not exist", username)
 	}
 	if password == "" {
@@ -265,7 +265,7 @@ func execUserChangeRole(c *cli.Context) error {
 	if err != nil {
 		return err
 	}
-	if _, err := manager.User(username); err == user.ErrNotFound {
+	if _, err := manager.User(username); err == user.ErrUserNotFound {
 		return fmt.Errorf("user %s does not exist", username)
 	}
 	if err := manager.ChangeRole(username, role); err != nil {
@@ -289,7 +289,7 @@ func execUserChangeTier(c *cli.Context) error {
 	if err != nil {
 		return err
 	}
-	if _, err := manager.User(username); err == user.ErrNotFound {
+	if _, err := manager.User(username); err == user.ErrUserNotFound {
 		return fmt.Errorf("user %s does not exist", username)
 	}
 	if tier == tierReset {

+ 4 - 0
go.mod

@@ -46,6 +46,10 @@ require (
 	github.com/googleapis/gax-go/v2 v2.7.0 // indirect
 	github.com/pmezard/go-difflib v1.0.0 // indirect
 	github.com/russross/blackfriday/v2 v2.1.0 // indirect
+	github.com/stripe/stripe-go/v74 v74.5.0 // indirect
+	github.com/tidwall/gjson v1.14.4 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.1 // indirect
 	github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
 	go.opencensus.io v0.24.0 // indirect
 	golang.org/x/net v0.4.0 // indirect

+ 12 - 0
go.sum

@@ -95,10 +95,20 @@ github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQD
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
+github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stripe/stripe-go/v74 v74.5.0 h1:YyqTvVQdS34KYGCfVB87EMn9eDV3FCFkSwfdOQhiVL4=
+github.com/stripe/stripe-go/v74 v74.5.0/go.mod h1:5PoXNp30AJ3tGq57ZcFuaMylzNi8KpwlrYAFmO1fHZw=
+github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM=
+github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
 github.com/urfave/cli/v2 v2.23.7 h1:YHDQ46s3VghFHFf1DdF+Sh7H4RqhcM+t0TmZRJx4oJY=
 github.com/urfave/cli/v2 v2.23.7/go.mod h1:GHupkWPMM0M/sj1a2b4wUrWBPzazNrIjouW6fmdJLxc=
 github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
@@ -119,6 +129,7 @@ golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73r
 golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
 golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
+golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A=
 golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
 golang.org/x/net v0.0.0-20220708220712-1185a9018129/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
@@ -135,6 +146,7 @@ golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

+ 2 - 0
server/config.go

@@ -110,6 +110,8 @@ type Config struct {
 	VisitorAccountCreateLimitReplenish   time.Duration
 	VisitorStatsResetTime                time.Time // Time of the day at which to reset visitor stats
 	BehindProxy                          bool
+	StripeKey                            string
+	StripeWebhookKey                     string
 	EnableWeb                            bool
 	EnableSignup                         bool // Enable creation of accounts via API and UI
 	EnableLogin                          bool

+ 2 - 0
server/errors.go

@@ -58,6 +58,8 @@ var (
 	errHTTPBadRequestJSONInvalid                     = &errHTTP{40024, http.StatusBadRequest, "invalid request: request body must be valid JSON", ""}
 	errHTTPBadRequestPermissionInvalid               = &errHTTP{40025, http.StatusBadRequest, "invalid request: incorrect permission string", ""}
 	errHTTPBadRequestMakesNoSenseForAdmin            = &errHTTP{40026, http.StatusBadRequest, "invalid request: this makes no sense for admins", ""}
+	errHTTPBadRequestNotAPaidUser                    = &errHTTP{40027, http.StatusBadRequest, "invalid request: not a paid user", ""}
+	errHTTPBadRequestInvalidStripeRequest            = &errHTTP{40028, http.StatusBadRequest, "invalid request: not a valid Stripe request", ""}
 	errHTTPNotFound                                  = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
 	errHTTPUnauthorized                              = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", "https://ntfy.sh/docs/publish/#authentication"}
 	errHTTPForbidden                                 = &errHTTP{40301, http.StatusForbidden, "forbidden", "https://ntfy.sh/docs/publish/#authentication"}

+ 18 - 0
server/server.go

@@ -36,6 +36,10 @@ import (
 
 /*
 	TODO
+		payments:
+		- handle overdue payment (-> downgrade after 7 days)
+		- delete stripe subscription when acocunt is deleted
+
 		Limits & rate limiting:
 			users without tier: should the stats be persisted? are they meaningful?
 				-> test that the visitor is based on the IP address!
@@ -43,6 +47,7 @@ import (
 		update last_seen when API is accessed
 		Make sure account endpoints make sense for admins
 
+		triggerChange after publishing a message
 		UI:
 		- flicker of upgrade banner
 		- JS constants
@@ -100,6 +105,11 @@ var (
 	accountSettingsPath            = "/v1/account/settings"
 	accountSubscriptionPath        = "/v1/account/subscription"
 	accountReservationPath         = "/v1/account/reservation"
+	accountBillingPortalPath       = "/v1/account/billing/portal"
+	accountBillingWebhookPath      = "/v1/account/billing/webhook"
+	accountCheckoutPath            = "/v1/account/checkout"
+	accountCheckoutSuccessTemplate = "/v1/account/checkout/success/{CHECKOUT_SESSION_ID}"
+	accountCheckoutSuccessRegex    = regexp.MustCompile(`/v1/account/checkout/success/(.+)$`)
 	accountReservationSingleRegex  = regexp.MustCompile(`/v1/account/reservation/([-_A-Za-z0-9]{1,64})$`)
 	accountSubscriptionSingleRegex = regexp.MustCompile(`^/v1/account/subscription/([-_A-Za-z0-9]{16})$`)
 	matrixPushPath                 = "/_matrix/push/v1/notify"
@@ -362,6 +372,14 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
 		return s.ensureUser(s.handleAccountReservationAdd)(w, r, v)
 	} else if r.Method == http.MethodDelete && accountReservationSingleRegex.MatchString(r.URL.Path) {
 		return s.ensureUser(s.handleAccountReservationDelete)(w, r, v)
+	} else if r.Method == http.MethodPost && r.URL.Path == accountCheckoutPath {
+		return s.ensureUser(s.handleAccountCheckoutSessionCreate)(w, r, v)
+	} else if r.Method == http.MethodGet && accountCheckoutSuccessRegex.MatchString(r.URL.Path) {
+		return s.ensureUserManager(s.handleAccountCheckoutSessionSuccessGet)(w, r, v) // No user context!
+	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingPortalPath {
+		return s.ensureUser(s.handleAccountBillingPortalSessionCreate)(w, r, v)
+	} else if r.Method == http.MethodPost && r.URL.Path == accountBillingWebhookPath {
+		return s.ensureUserManager(s.handleAccountBillingWebhookTrigger)(w, r, v)
 	} else if r.Method == http.MethodGet && r.URL.Path == matrixPushPath {
 		return s.handleMatrixDiscovery(w)
 	} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {

+ 232 - 0
server/server_account.go

@@ -2,6 +2,14 @@ package server
 
 import (
 	"encoding/json"
+	"errors"
+	"github.com/stripe/stripe-go/v74"
+	portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
+	"github.com/stripe/stripe-go/v74/checkout/session"
+	"github.com/stripe/stripe-go/v74/subscription"
+	"github.com/stripe/stripe-go/v74/webhook"
+	"github.com/tidwall/gjson"
+	"heckel.io/ntfy/log"
 	"heckel.io/ntfy/user"
 	"heckel.io/ntfy/util"
 	"net/http"
@@ -9,6 +17,7 @@ import (
 
 const (
 	jsonBodyBytesLimit   = 4096
+	stripeBodyBytesLimit = 16384
 	subscriptionIDLength = 16
 	createdByAPI         = "api"
 )
@@ -386,3 +395,226 @@ func (s *Server) handleAccountReservationDelete(w http.ResponseWriter, r *http.R
 	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
 	return nil
 }
+
+func (s *Server) handleAccountCheckoutSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	req, err := readJSONWithLimit[apiAccountTierChangeRequest](r.Body, jsonBodyBytesLimit)
+	if err != nil {
+		return err
+	}
+	tier, err := s.userManager.Tier(req.Tier)
+	if err != nil {
+		return err
+	}
+	if tier.StripePriceID == "" {
+		log.Info("Checkout: Downgrading to no tier")
+		return errors.New("not a paid tier")
+	} else if v.user.Billing != nil && v.user.Billing.StripeSubscriptionID != "" {
+		log.Info("Checkout: Changing tier and subscription to %s", tier.Code)
+
+		// Upgrade/downgrade tier
+		sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
+		if err != nil {
+			return err
+		}
+		params := &stripe.SubscriptionParams{
+			CancelAtPeriodEnd: stripe.Bool(false),
+			ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
+			Items: []*stripe.SubscriptionItemsParams{
+				{
+					ID:    stripe.String(sub.Items.Data[0].ID),
+					Price: stripe.String(tier.StripePriceID),
+				},
+			},
+		}
+		_, err = subscription.Update(sub.ID, params)
+		if err != nil {
+			return err
+		}
+		response := &apiAccountCheckoutResponse{}
+		w.Header().Set("Content-Type", "application/json")
+		w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
+		if err := json.NewEncoder(w).Encode(response); err != nil {
+			return err
+		}
+		return nil
+	} else {
+		// Checkout flow
+		log.Info("Checkout: No existing subscription, creating checkout flow")
+	}
+
+	successURL := s.config.BaseURL + accountCheckoutSuccessTemplate
+	var stripeCustomerID *string
+	if v.user.Billing != nil {
+		stripeCustomerID = &v.user.Billing.StripeCustomerID
+	}
+	params := &stripe.CheckoutSessionParams{
+		ClientReferenceID: &v.user.Name, // FIXME Should be user ID
+		Customer:          stripeCustomerID,
+		SuccessURL:        &successURL,
+		Mode:              stripe.String(string(stripe.CheckoutSessionModeSubscription)),
+		LineItems: []*stripe.CheckoutSessionLineItemParams{
+			{
+				Price:    stripe.String(tier.StripePriceID),
+				Quantity: stripe.Int64(1),
+			},
+		},
+	}
+	sess, err := session.New(params)
+	if err != nil {
+		return err
+	}
+	response := &apiAccountCheckoutResponse{
+		RedirectURL: sess.URL,
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
+	if err := json.NewEncoder(w).Encode(response); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *Server) handleAccountCheckoutSessionSuccessGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	// We don't have a v.user in this endpoint, only a userManager!
+	matches := accountCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
+	if len(matches) != 2 {
+		return errHTTPInternalErrorInvalidPath
+	}
+	sessionID := matches[1]
+	// FIXME how do I rate limit this?
+	sess, err := session.Get(sessionID, nil)
+	if err != nil {
+		log.Warn("Stripe: %s", err)
+		return errHTTPBadRequestInvalidStripeRequest
+	} else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
+		log.Warn("Stripe: Unexpected session, customer or subscription not found")
+		return errHTTPBadRequestInvalidStripeRequest
+	}
+	sub, err := subscription.Get(sess.Subscription.ID, nil)
+	if err != nil {
+		return err
+	} else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
+		log.Error("Stripe: Unexpected subscription, expected exactly one line item")
+		return errHTTPBadRequestInvalidStripeRequest
+	}
+	priceID := sub.Items.Data[0].Price.ID
+	tier, err := s.userManager.TierByStripePrice(priceID)
+	if err != nil {
+		return err
+	}
+	u, err := s.userManager.User(sess.ClientReferenceID)
+	if err != nil {
+		return err
+	}
+	if u.Billing == nil {
+		u.Billing = &user.Billing{}
+	}
+	u.Billing.StripeCustomerID = sess.Customer.ID
+	u.Billing.StripeSubscriptionID = sess.Subscription.ID
+	if err := s.userManager.ChangeBilling(u); err != nil {
+		return err
+	}
+	if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
+		return err
+	}
+	accountURL := s.config.BaseURL + "/account" // FIXME
+	http.Redirect(w, r, accountURL, http.StatusSeeOther)
+	return nil
+}
+
+func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	if v.user.Billing == nil {
+		return errHTTPBadRequestNotAPaidUser
+	}
+	params := &stripe.BillingPortalSessionParams{
+		Customer:  stripe.String(v.user.Billing.StripeCustomerID),
+		ReturnURL: stripe.String(s.config.BaseURL),
+	}
+	ps, err := portalsession.New(params)
+	if err != nil {
+		return err
+	}
+	response := &apiAccountBillingPortalRedirectResponse{
+		RedirectURL: ps.URL,
+	}
+	w.Header().Set("Content-Type", "application/json")
+	w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
+	if err := json.NewEncoder(w).Encode(response); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (s *Server) handleAccountBillingWebhookTrigger(w http.ResponseWriter, r *http.Request, v *visitor) error {
+	// We don't have a v.user in this endpoint, only a userManager!
+	stripeSignature := r.Header.Get("Stripe-Signature")
+	if stripeSignature == "" {
+		return errHTTPBadRequestInvalidStripeRequest
+	}
+	body, err := util.Peek(r.Body, stripeBodyBytesLimit)
+	if err != nil {
+		return err
+	} else if body.LimitReached {
+		return errHTTPEntityTooLargeJSONBody
+	}
+	event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
+	if err != nil {
+		log.Warn("Stripe: invalid request: %s", err.Error())
+		return errHTTPBadRequestInvalidStripeRequest
+	} else if event.Data == nil || event.Data.Raw == nil {
+		log.Warn("Stripe: invalid request, data is nil")
+		return errHTTPBadRequestInvalidStripeRequest
+	}
+	log.Info("Stripe: webhook event %s received", event.Type)
+	stripeCustomerID := gjson.GetBytes(event.Data.Raw, "customer")
+	if !stripeCustomerID.Exists() {
+		return errHTTPBadRequestInvalidStripeRequest
+	}
+	switch event.Type {
+	case "checkout.session.completed":
+		// Payment is successful and the subscription is created.
+		// Provision the subscription, save the customer ID.
+		return s.handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID.String(), event.Data.Raw)
+	case "customer.subscription.updated":
+		return s.handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID.String(), event.Data.Raw)
+	case "invoice.paid":
+		// Continue to provision the subscription as payments continue to be made.
+		// Store the status in your database and check when a user accesses your service.
+		// This approach helps you avoid hitting rate limits.
+		return nil // FIXME
+	case "invoice.payment_failed":
+		// The payment failed or the customer does not have a valid payment method.
+		// The subscription becomes past_due. Notify your customer and send them to the
+		// customer portal to update their payment information.
+		return nil // FIXME
+	default:
+		log.Warn("Stripe: unhandled webhook %s", event.Type)
+		return nil
+	}
+}
+
+func (s *Server) handleAccountBillingWebhookCheckoutCompleted(stripeCustomerID string, event json.RawMessage) error {
+	log.Info("Stripe: checkout completed for customer %s", stripeCustomerID)
+	return nil
+}
+
+func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(stripeCustomerID string, event json.RawMessage) error {
+	status := gjson.GetBytes(event, "status")
+	priceID := gjson.GetBytes(event, "items.data.0.price.id")
+	if !status.Exists() || !priceID.Exists() {
+		return errHTTPBadRequestInvalidStripeRequest
+	}
+	log.Info("Stripe: customer %s: subscription updated to %s, with price %s", stripeCustomerID, status, priceID)
+	u, err := s.userManager.UserByStripeCustomer(stripeCustomerID)
+	if err != nil {
+		return err
+	}
+	tier, err := s.userManager.TierByStripePrice(priceID.String())
+	if err != nil {
+		return err
+	}
+	if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
+		return err
+	}
+	return nil
+}

+ 12 - 0
server/types.go

@@ -295,3 +295,15 @@ type apiConfigResponse struct {
 	EnableReservations bool     `json:"enable_reservations"`
 	DisallowedTopics   []string `json:"disallowed_topics"`
 }
+
+type apiAccountTierChangeRequest struct {
+	Tier string `json:"tier"`
+}
+
+type apiAccountCheckoutResponse struct {
+	RedirectURL string `json:"redirect_url"`
+}
+
+type apiAccountBillingPortalRedirectResponse struct {
+	RedirectURL string `json:"redirect_url"`
+}

+ 102 - 8
user/manager.go

@@ -44,8 +44,11 @@ const (
 			reservations_limit INT NOT NULL,
 			attachment_file_size_limit INT NOT NULL,
 			attachment_total_size_limit INT NOT NULL,
-			attachment_expiry_duration INT NOT NULL
+			attachment_expiry_duration INT NOT NULL,
+			stripe_price_id TEXT
 		);
+		CREATE UNIQUE INDEX idx_tier_code ON tier (code);
+		CREATE UNIQUE INDEX idx_tier_price_id ON tier (stripe_price_id);
 		CREATE TABLE IF NOT EXISTS user (
 		    id INTEGER PRIMARY KEY AUTOINCREMENT,
 			tier_id INT,
@@ -56,12 +59,16 @@ const (
 			sync_topic TEXT NOT NULL,
 			stats_messages INT NOT NULL DEFAULT (0),
 			stats_emails INT NOT NULL DEFAULT (0),
+			stripe_customer_id TEXT,
+			stripe_subscription_id TEXT,			
 			created_by TEXT NOT NULL,
 			created_at INT NOT NULL,
 			last_seen INT NOT NULL,
 		    FOREIGN KEY (tier_id) REFERENCES tier (id)
 		);
 		CREATE UNIQUE INDEX idx_user ON user (user);
+		CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
+		CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
 		CREATE TABLE IF NOT EXISTS user_access (
 			user_id INT NOT NULL,
 			topic TEXT NOT NULL,
@@ -93,18 +100,24 @@ const (
 	`
 
 	selectUserByNameQuery = `
-		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
+		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
 		FROM user u
 		LEFT JOIN tier p on p.id = u.tier_id
 		WHERE user = ?		
 	`
 	selectUserByTokenQuery = `
-		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration
+		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
 		FROM user u
 		JOIN user_token t on u.id = t.user_id
 		LEFT JOIN tier p on p.id = u.tier_id
 		WHERE t.token = ? AND t.expires >= ?
 	`
+	selectUserByStripeCustomerIDQuery = `
+		SELECT u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, p.code, p.name, p.paid, p.messages_limit, p.messages_expiry_duration, p.emails_limit, p.reservations_limit, p.attachment_file_size_limit, p.attachment_total_size_limit, p.attachment_expiry_duration, p.stripe_price_id
+		FROM user u
+		LEFT JOIN tier p on p.id = u.tier_id
+		WHERE u.stripe_customer_id = ?
+	`
 	selectTopicPermsQuery = `
 		SELECT read, write
 		FROM user_access a
@@ -204,9 +217,21 @@ const (
 		INSERT INTO tier (code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration)
 		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
 	`
-	selectTierIDQuery   = `SELECT id FROM tier WHERE code = ?`
+	selectTierIDQuery     = `SELECT id FROM tier WHERE code = ?`
+	selectTierByCodeQuery = `
+		SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
+		FROM tier
+		WHERE code = ?
+	`
+	selectTierByPriceIDQuery = `
+		SELECT code, name, paid, messages_limit, messages_expiry_duration, emails_limit, reservations_limit, attachment_file_size_limit, attachment_total_size_limit, attachment_expiry_duration, stripe_price_id
+		FROM tier
+		WHERE stripe_price_id = ?
+	`
 	updateUserTierQuery = `UPDATE user SET tier_id = ? WHERE user = ?`
 	deleteUserTierQuery = `UPDATE user SET tier_id = null WHERE user = ?`
+
+	updateBillingQuery = `UPDATE user SET stripe_customer_id = ?, stripe_subscription_id = ? WHERE user = ?`
 )
 
 // Schema management queries
@@ -543,7 +568,7 @@ func (a *Manager) Users() ([]*User, error) {
 	return users, nil
 }
 
-// User returns the user with the given username if it exists, or ErrNotFound otherwise.
+// User returns the user with the given username if it exists, or ErrUserNotFound otherwise.
 // You may also pass Everyone to retrieve the anonymous user and its Grant list.
 func (a *Manager) User(username string) (*User, error) {
 	rows, err := a.db.Query(selectUserByNameQuery, username)
@@ -553,6 +578,14 @@ func (a *Manager) User(username string) (*User, error) {
 	return a.readUser(rows)
 }
 
+func (a *Manager) UserByStripeCustomer(stripeCustomerID string) (*User, error) {
+	rows, err := a.db.Query(selectUserByStripeCustomerIDQuery, stripeCustomerID)
+	if err != nil {
+		return nil, err
+	}
+	return a.readUser(rows)
+}
+
 func (a *Manager) userByToken(token string) (*User, error) {
 	rows, err := a.db.Query(selectUserByTokenQuery, token, time.Now().Unix())
 	if err != nil {
@@ -564,14 +597,14 @@ func (a *Manager) userByToken(token string) (*User, error) {
 func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	defer rows.Close()
 	var username, hash, role, prefs, syncTopic string
-	var tierCode, tierName sql.NullString
+	var stripeCustomerID, stripeSubscriptionID, stripePriceID, tierCode, tierName sql.NullString
 	var paid sql.NullBool
 	var messages, emails int64
 	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
 	if !rows.Next() {
-		return nil, ErrNotFound
+		return nil, ErrUserNotFound
 	}
-	if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration); err != nil {
+	if err := rows.Scan(&username, &hash, &role, &prefs, &syncTopic, &messages, &emails, &stripeCustomerID, &stripeSubscriptionID, &tierCode, &tierName, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
 		return nil, err
 	} else if err := rows.Err(); err != nil {
 		return nil, err
@@ -590,7 +623,14 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 	if err := json.Unmarshal([]byte(prefs), user.Prefs); err != nil {
 		return nil, err
 	}
+	if stripeCustomerID.Valid && stripeSubscriptionID.Valid {
+		user.Billing = &Billing{
+			StripeCustomerID:     stripeCustomerID.String,
+			StripeSubscriptionID: stripeSubscriptionID.String,
+		}
+	}
 	if tierCode.Valid {
+		// See readTier() when this is changed!
 		user.Tier = &Tier{
 			Code:                     tierCode.String,
 			Name:                     tierName.String,
@@ -602,6 +642,7 @@ func (a *Manager) readUser(rows *sql.Rows) (*User, error) {
 			AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
 			AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
 			AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
+			StripePriceID:            stripePriceID.String,
 		}
 	}
 	return user, nil
@@ -826,6 +867,59 @@ func (a *Manager) CreateTier(tier *Tier) error {
 	return nil
 }
 
+func (a *Manager) ChangeBilling(user *User) error {
+	if _, err := a.db.Exec(updateBillingQuery, user.Billing.StripeCustomerID, user.Billing.StripeSubscriptionID, user.Name); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (a *Manager) Tier(code string) (*Tier, error) {
+	rows, err := a.db.Query(selectTierByCodeQuery, code)
+	if err != nil {
+		return nil, err
+	}
+	return a.readTier(rows)
+}
+
+func (a *Manager) TierByStripePrice(priceID string) (*Tier, error) {
+	rows, err := a.db.Query(selectTierByPriceIDQuery, priceID)
+	if err != nil {
+		return nil, err
+	}
+	return a.readTier(rows)
+}
+
+func (a *Manager) readTier(rows *sql.Rows) (*Tier, error) {
+	defer rows.Close()
+	var code, name string
+	var stripePriceID sql.NullString
+	var paid bool
+	var messagesLimit, messagesExpiryDuration, emailsLimit, reservationsLimit, attachmentFileSizeLimit, attachmentTotalSizeLimit, attachmentExpiryDuration sql.NullInt64
+	if !rows.Next() {
+		return nil, ErrTierNotFound
+	}
+	if err := rows.Scan(&code, &name, &paid, &messagesLimit, &messagesExpiryDuration, &emailsLimit, &reservationsLimit, &attachmentFileSizeLimit, &attachmentTotalSizeLimit, &attachmentExpiryDuration, &stripePriceID); err != nil {
+		return nil, err
+	} else if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	// When changed, note readUser() as well
+	return &Tier{
+		Code:                     code,
+		Name:                     name,
+		Paid:                     paid,
+		MessagesLimit:            messagesLimit.Int64,
+		MessagesExpiryDuration:   time.Duration(messagesExpiryDuration.Int64) * time.Second,
+		EmailsLimit:              emailsLimit.Int64,
+		ReservationsLimit:        reservationsLimit.Int64,
+		AttachmentFileSizeLimit:  attachmentFileSizeLimit.Int64,
+		AttachmentTotalSizeLimit: attachmentTotalSizeLimit.Int64,
+		AttachmentExpiryDuration: time.Duration(attachmentExpiryDuration.Int64) * time.Second,
+		StripePriceID:            stripePriceID.String, // May be empty!
+	}, nil
+}
+
 func toSQLWildcard(s string) string {
 	return strings.ReplaceAll(s, "*", "%")
 }

+ 1 - 1
user/manager_test.go

@@ -208,7 +208,7 @@ func TestManager_UserManagement(t *testing.T) {
 	// Remove user
 	require.Nil(t, a.RemoveUser("ben"))
 	_, err = a.User("ben")
-	require.Equal(t, ErrNotFound, err)
+	require.Equal(t, ErrUserNotFound, err)
 
 	users, err = a.Users()
 	require.Nil(t, err)

+ 10 - 1
user/types.go

@@ -16,6 +16,7 @@ type User struct {
 	Prefs     *Prefs
 	Tier      *Tier
 	Stats     *Stats
+	Billing   *Billing
 	SyncTopic string
 	Created   time.Time
 	LastSeen  time.Time
@@ -58,6 +59,7 @@ type Tier struct {
 	AttachmentFileSizeLimit  int64
 	AttachmentTotalSizeLimit int64
 	AttachmentExpiryDuration time.Duration
+	StripePriceID            string
 }
 
 // Subscription represents a user's topic subscription
@@ -81,6 +83,12 @@ type Stats struct {
 	Emails   int64
 }
 
+// Billing is a struct holding a user's billing information
+type Billing struct {
+	StripeCustomerID     string
+	StripeSubscriptionID string
+}
+
 // Grant is a struct that represents an access control entry to a topic by a user
 type Grant struct {
 	TopicPattern string // May include wildcard (*)
@@ -212,5 +220,6 @@ var (
 	ErrUnauthenticated = errors.New("unauthenticated")
 	ErrUnauthorized    = errors.New("unauthorized")
 	ErrInvalidArgument = errors.New("invalid argument")
-	ErrNotFound        = errors.New("not found")
+	ErrUserNotFound    = errors.New("user not found")
+	ErrTierNotFound    = errors.New("tier not found")
 )

+ 36 - 3
web/src/app/AccountApi.js

@@ -8,7 +8,7 @@ import {
     accountTokenUrl,
     accountUrl, maybeWithAuth, topicUrl,
     withBasicAuth,
-    withBearerAuth
+    withBearerAuth, accountCheckoutUrl, accountBillingPortalUrl
 } from "./utils";
 import session from "./Session";
 import subscriptionManager from "./SubscriptionManager";
@@ -228,7 +228,7 @@ class AccountApi {
         this.triggerChange(); // Dangle!
     }
 
-    async upsertAccess(topic, everyone) {
+    async upsertReservation(topic, everyone) {
         const url = accountReservationUrl(config.base_url);
         console.log(`[AccountApi] Upserting user access to topic ${topic}, everyone=${everyone}`);
         const response = await fetch(url, {
@@ -249,7 +249,7 @@ class AccountApi {
         this.triggerChange(); // Dangle!
     }
 
-    async deleteAccess(topic) {
+    async deleteReservation(topic) {
         const url = accountReservationSingleUrl(config.base_url, topic);
         console.log(`[AccountApi] Removing topic reservation ${url}`);
         const response = await fetch(url, {
@@ -264,6 +264,39 @@ class AccountApi {
         this.triggerChange(); // Dangle!
     }
 
+    async createCheckoutSession(tier) {
+        const url = accountCheckoutUrl(config.base_url);
+        console.log(`[AccountApi] Creating checkout session`);
+        const response = await fetch(url, {
+            method: "POST",
+            headers: withBearerAuth({}, session.token()),
+            body: JSON.stringify({
+                tier: tier
+            })
+        });
+        if (response.status === 401 || response.status === 403) {
+            throw new UnauthorizedError();
+        } else if (response.status !== 200) {
+            throw new Error(`Unexpected server response ${response.status}`);
+        }
+        return await response.json();
+    }
+
+    async createBillingPortalSession() {
+        const url = accountBillingPortalUrl(config.base_url);
+        console.log(`[AccountApi] Creating billing portal session`);
+        const response = await fetch(url, {
+            method: "POST",
+            headers: withBearerAuth({}, session.token())
+        });
+        if (response.status === 401 || response.status === 403) {
+            throw new UnauthorizedError();
+        } else if (response.status !== 200) {
+            throw new Error(`Unexpected server response ${response.status}`);
+        }
+        return await response.json();
+    }
+
     async sync() {
         try {
             if (!session.token()) {

+ 2 - 0
web/src/app/utils.js

@@ -26,6 +26,8 @@ export const accountSubscriptionUrl = (baseUrl) => `${baseUrl}/v1/account/subscr
 export const accountSubscriptionSingleUrl = (baseUrl, id) => `${baseUrl}/v1/account/subscription/${id}`;
 export const accountReservationUrl = (baseUrl) => `${baseUrl}/v1/account/reservation`;
 export const accountReservationSingleUrl = (baseUrl, topic) => `${baseUrl}/v1/account/reservation/${topic}`;
+export const accountCheckoutUrl = (baseUrl) => `${baseUrl}/v1/account/checkout`;
+export const accountBillingPortalUrl = (baseUrl) => `${baseUrl}/v1/account/billing/portal`;
 export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
 export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
 export const expandSecureUrl = (url) => `https://${url}`;

+ 33 - 7
web/src/components/Account.js

@@ -171,10 +171,28 @@ const Stats = () => {
     const { t } = useTranslation();
     const { account } = useContext(AccountContext);
     const [upgradeDialogOpen, setUpgradeDialogOpen] = useState(false);
+
     if (!account) {
         return <></>;
     }
-    const normalize = (value, max) => Math.min(value / max * 100, 100);
+
+    const normalize = (value, max) => {
+        return Math.min(value / max * 100, 100);
+    };
+
+    const handleManageBilling = async () => {
+        try {
+            const response = await accountApi.createBillingPortalSession();
+            window.location.href = response.redirect_url;
+        } catch (e) {
+            console.log(`[Account] Error changing password`, e);
+            if ((e instanceof UnauthorizedError)) {
+                session.resetAndRedirect(routes.login);
+            }
+            // TODO show error
+        }
+    };
+
     return (
         <Card sx={{p: 3}} aria-label={t("account_usage_title")}>
             <Typography variant="h5" sx={{marginBottom: 2}}>
@@ -201,12 +219,20 @@ const Stats = () => {
                             >{t("account_usage_tier_upgrade_button")}</Button>
                         }
                         {config.enable_payments && account.role === "user" && account.tier?.paid &&
-                            <Button
-                                variant="outlined"
-                                size="small"
-                                onClick={() => setUpgradeDialogOpen(true)}
-                                sx={{ml: 1}}
-                            >{t("account_usage_tier_change_button")}</Button>
+                            <>
+                                <Button
+                                    variant="outlined"
+                                    size="small"
+                                    onClick={() => setUpgradeDialogOpen(true)}
+                                    sx={{ml: 1}}
+                                >{t("account_usage_tier_change_button")}</Button>
+                                <Button
+                                    variant="outlined"
+                                    size="small"
+                                    onClick={handleManageBilling}
+                                    sx={{ml: 1}}
+                                >Manage billing</Button>
+                            </>
                         }
                         <UpgradeDialog
                             open={upgradeDialogOpen}

+ 3 - 3
web/src/components/Preferences.js

@@ -501,7 +501,7 @@ const Reservations = () => {
     const handleDialogSubmit = async (reservation) => {
         setDialogOpen(false);
         try {
-            await accountApi.upsertAccess(reservation.topic, reservation.everyone);
+            await accountApi.upsertReservation(reservation.topic, reservation.everyone);
             await accountApi.sync();
             console.debug(`[Preferences] Added topic reservation`, reservation);
         } catch (e) {
@@ -557,7 +557,7 @@ const ReservationsTable = (props) => {
     const handleDialogSubmit = async (reservation) => {
         setDialogOpen(false);
         try {
-            await accountApi.upsertAccess(reservation.topic, reservation.everyone);
+            await accountApi.upsertReservation(reservation.topic, reservation.everyone);
             await accountApi.sync();
             console.debug(`[Preferences] Added topic reservation`, reservation);
         } catch (e) {
@@ -568,7 +568,7 @@ const ReservationsTable = (props) => {
 
     const handleDeleteClick = async (reservation) => {
         try {
-            await accountApi.deleteAccess(reservation.topic);
+            await accountApi.deleteReservation(reservation.topic);
             await accountApi.sync();
             console.debug(`[Preferences] Deleted topic reservation`, reservation);
         } catch (e) {

+ 1 - 1
web/src/components/SubscribeDialog.js

@@ -110,7 +110,7 @@ const SubscribePage = (props) => {
         if (session.exists() && baseUrl === config.base_url && reserveTopicVisible) {
             console.log(`[SubscribeDialog] Reserving topic ${topic} with everyone access ${everyone}`);
             try {
-                await accountApi.upsertAccess(topic, everyone);
+                await accountApi.upsertReservation(topic, everyone);
                 // Account sync later after it was added
             } catch (e) {
                 console.log(`[SubscribeDialog] Error reserving topic`, e);

+ 2 - 2
web/src/components/SubscriptionSettingsDialog.js

@@ -37,9 +37,9 @@ const SubscriptionSettingsDialog = (props) => {
 
                 // Reservation
                 if (reserveTopicVisible) {
-                    await accountApi.upsertAccess(subscription.topic, everyone);
+                    await accountApi.upsertReservation(subscription.topic, everyone);
                 } else if (!reserveTopicVisible && subscription.reservation) { // Was removed
-                    await accountApi.deleteAccess(subscription.topic);
+                    await accountApi.deleteReservation(subscription.topic);
                 }
 
                 // Sync account

+ 62 - 7
web/src/components/UpgradeDialog.js

@@ -2,28 +2,83 @@ import * as React from 'react';
 import Dialog from '@mui/material/Dialog';
 import DialogContent from '@mui/material/DialogContent';
 import DialogTitle from '@mui/material/DialogTitle';
-import {useMediaQuery} from "@mui/material";
+import {CardActionArea, CardContent, useMediaQuery} from "@mui/material";
 import theme from "./theme";
 import DialogFooter from "./DialogFooter";
+import Button from "@mui/material/Button";
+import accountApi, {TopicReservedError, UnauthorizedError} from "../app/AccountApi";
+import session from "../app/Session";
+import routes from "./routes";
+import {useContext, useState} from "react";
+import Card from "@mui/material/Card";
+import Typography from "@mui/material/Typography";
+import {AccountContext} from "./App";
 
 const UpgradeDialog = (props) => {
+    const { account } = useContext(AccountContext);
     const fullScreen = useMediaQuery(theme.breakpoints.down('sm'));
+    const [selected, setSelected] = useState(account?.tier?.code || null);
+    const [errorText, setErrorText] = useState("");
 
-    const handleSuccess = async () => {
-        // TODO
+    const handleCheckout = async () => {
+        try {
+            const response = await accountApi.createCheckoutSession(selected);
+            if (response.redirect_url) {
+                window.location.href = response.redirect_url;
+            } else {
+                await accountApi.sync();
+            }
+
+        } catch (e) {
+            console.log(`[UpgradeDialog] Error creating checkout session`, e);
+            if ((e instanceof UnauthorizedError)) {
+                session.resetAndRedirect(routes.login);
+            }
+            // FIXME show error
+        }
     }
 
     return (
-        <Dialog open={props.open} onClose={props.onCancel} fullScreen={fullScreen}>
+        <Dialog open={props.open} onClose={props.onCancel} maxWidth="md" fullScreen={fullScreen}>
             <DialogTitle>Upgrade to Pro</DialogTitle>
             <DialogContent>
-                Content
+                <div style={{
+                    display: "flex",
+                    flexDirection: "row"
+                }}>
+                    <TierCard code={null} name={"Free"} selected={selected === null} onClick={() => setSelected(null)}/>
+                    <TierCard code="starter" name={"Starter"} selected={selected === "starter"} onClick={() => setSelected("starter")}/>
+                    <TierCard code="pro" name={"Pro"} selected={selected === "pro"} onClick={() => setSelected("pro")}/>
+                    <TierCard code="business" name={"Business"} selected={selected === "business"} onClick={() => setSelected("business")}/>
+                </div>
             </DialogContent>
-            <DialogFooter>
-                Footer
+            <DialogFooter status={errorText}>
+                <Button onClick={handleCheckout}>Checkout</Button>
             </DialogFooter>
         </Dialog>
     );
 };
 
+const TierCard = (props) => {
+    const cardStyle = (props.selected) ? {
+        border: "1px solid red",
+
+    } : {};
+    return (
+        <Card sx={{ m: 1, maxWidth: 345 }}>
+            <CardActionArea>
+                <CardContent sx={{...cardStyle}} onClick={props.onClick}>
+                    <Typography gutterBottom variant="h5" component="div">
+                        {props.name}
+                    </Typography>
+                    <Typography variant="body2" color="text.secondary">
+                        Lizards are a widespread group of squamate reptiles, with over 6,000
+                        species, ranging across all continents except Antarctica
+                    </Typography>
+                </CardContent>
+            </CardActionArea>
+        </Card>
+    );
+}
+
 export default UpgradeDialog;