server_payments.go 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. package server
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/stripe/stripe-go/v74"
  8. portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
  9. "github.com/stripe/stripe-go/v74/checkout/session"
  10. "github.com/stripe/stripe-go/v74/customer"
  11. "github.com/stripe/stripe-go/v74/price"
  12. "github.com/stripe/stripe-go/v74/subscription"
  13. "github.com/stripe/stripe-go/v74/webhook"
  14. "heckel.io/ntfy/log"
  15. "heckel.io/ntfy/user"
  16. "heckel.io/ntfy/util"
  17. "io"
  18. "net/http"
  19. "net/netip"
  20. "strings"
  21. "time"
  22. )
// Payments in ntfy are done via Stripe.
//
// Almost everything payments-related lives in this file. The following processes
// handle payments:
//
// - Checkout:
// Creating a Stripe subscription (and, if necessary, a Stripe customer) via the Checkout flow. This flow
// is only used if the ntfy user does not already have a Stripe subscription; an existing Stripe customer
// (e.g. one whose previous subscription was deleted) is reused. It requires redirecting to the Stripe
// checkout page. It is implemented in handleAccountBillingSubscriptionCreate and the success callback
// handleAccountBillingSubscriptionCreateSuccess.
// - Update subscription:
// Switching the existing Stripe subscription to a different price (upgrade/downgrade) is handled via
// handleAccountBillingSubscriptionUpdate. This also handles proration.
// - Cancel subscription (at period end):
// Users can cancel the Stripe subscription via the web app; the cancellation takes effect at the end of
// the billing period. This simply updates the subscription, and Stripe cancels it when the period ends.
// Users cannot cancel the subscription immediately.
// - Webhooks:
// Whenever a subscription changes (updated, deleted), Stripe sends us a request via a webhook.
// This is used to keep the local user database fields up to date. Stripe is the source of truth:
// what Stripe says is mirrored and not questioned.
  44. var (
  45. errNotAPaidTier = errors.New("tier does not have billing price identifier")
  46. errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
  47. errNoBillingSubscription = errors.New("user does not have an active billing subscription")
  48. )
  49. var (
  50. retryUserDelays = []time.Duration{3 * time.Second, 5 * time.Second, 7 * time.Second}
  51. )
  52. // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
  53. // in the UI. Note that this endpoint does NOT have a user context (no v.user!).
  54. func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  55. tiers, err := s.userManager.Tiers()
  56. if err != nil {
  57. return err
  58. }
  59. freeTier := defaultVisitorLimits(s.config)
  60. response := []*apiAccountBillingTier{
  61. {
  62. // This is a bit of a hack: This is the "Free" tier. It has no tier code, name or price.
  63. Limits: &apiAccountLimits{
  64. Messages: freeTier.MessagesLimit,
  65. MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
  66. Emails: freeTier.EmailsLimit,
  67. Reservations: freeTier.ReservationsLimit,
  68. AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
  69. AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
  70. AttachmentExpiryDuration: int64(freeTier.AttachmentExpiryDuration.Seconds()),
  71. },
  72. },
  73. }
  74. prices, err := s.priceCache.Value()
  75. if err != nil {
  76. return err
  77. }
  78. for _, tier := range tiers {
  79. priceStr, ok := prices[tier.StripePriceID]
  80. if tier.StripePriceID == "" || !ok {
  81. continue
  82. }
  83. response = append(response, &apiAccountBillingTier{
  84. Code: tier.Code,
  85. Name: tier.Name,
  86. Price: priceStr,
  87. Limits: &apiAccountLimits{
  88. Messages: tier.MessagesLimit,
  89. MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
  90. Emails: tier.EmailsLimit,
  91. Reservations: tier.ReservationsLimit,
  92. AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
  93. AttachmentFileSize: tier.AttachmentFileSizeLimit,
  94. AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
  95. },
  96. })
  97. }
  98. return s.writeJSON(w, response)
  99. }
  100. // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
  101. // will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
  102. func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  103. if v.user.Billing.StripeSubscriptionID != "" {
  104. return errHTTPBadRequestBillingSubscriptionExists
  105. }
  106. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  107. if err != nil {
  108. return err
  109. }
  110. tier, err := s.userManager.Tier(req.Tier)
  111. if err != nil {
  112. return err
  113. } else if tier.StripePriceID == "" {
  114. return errNotAPaidTier
  115. }
  116. log.Info("%s Creating Stripe checkout flow", logHTTPPrefix(v, r))
  117. var stripeCustomerID *string
  118. if v.user.Billing.StripeCustomerID != "" {
  119. stripeCustomerID = &v.user.Billing.StripeCustomerID
  120. stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
  121. if err != nil {
  122. return err
  123. } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
  124. return errMultipleBillingSubscriptions
  125. }
  126. }
  127. successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
  128. params := &stripe.CheckoutSessionParams{
  129. Customer: stripeCustomerID, // A user may have previously deleted their subscription
  130. ClientReferenceID: &v.user.ID,
  131. SuccessURL: &successURL,
  132. Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
  133. AllowPromotionCodes: stripe.Bool(true),
  134. LineItems: []*stripe.CheckoutSessionLineItemParams{
  135. {
  136. Price: stripe.String(tier.StripePriceID),
  137. Quantity: stripe.Int64(1),
  138. },
  139. },
  140. }
  141. sess, err := s.stripe.NewCheckoutSession(params)
  142. if err != nil {
  143. return err
  144. }
  145. response := &apiAccountBillingSubscriptionCreateResponse{
  146. RedirectURL: sess.URL,
  147. }
  148. return s.writeJSON(w, response)
  149. }
  150. // handleAccountBillingSubscriptionCreateSuccess is called after the Stripe checkout session has succeeded. We use
  151. // the session ID in the URL to retrieve the Stripe subscription and update the local database. This is the first
  152. // and only time we can map the local username with the Stripe customer ID.
  153. func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, v *visitor) error {
  154. // We don't have a v.user in this endpoint, only a userManager!
  155. matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
  156. if len(matches) != 2 {
  157. return errHTTPInternalErrorInvalidPath
  158. }
  159. sessionID := matches[1]
  160. sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
  161. if err != nil {
  162. return err
  163. } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
  164. return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
  165. }
  166. sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
  167. if err != nil {
  168. return err
  169. } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
  170. return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
  171. }
  172. tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
  173. if err != nil {
  174. return err
  175. }
  176. u, err := s.userManager.UserByID(sess.ClientReferenceID)
  177. if err != nil {
  178. return err
  179. }
  180. v.SetUser(u)
  181. if err := s.updateSubscriptionAndTier(logHTTPPrefix(v, r), u, tier, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt); err != nil {
  182. return err
  183. }
  184. http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
  185. return nil
  186. }
  187. // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
  188. // a user's tier accordingly. This endpoint only works if there is an existing subscription.
  189. func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  190. if v.user.Billing.StripeSubscriptionID == "" {
  191. return errNoBillingSubscription
  192. }
  193. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  194. if err != nil {
  195. return err
  196. }
  197. tier, err := s.userManager.Tier(req.Tier)
  198. if err != nil {
  199. return err
  200. }
  201. log.Info("%s Changing billing tier to %s (price %s) for subscription %s", logHTTPPrefix(v, r), tier.Code, tier.StripePriceID, v.user.Billing.StripeSubscriptionID)
  202. sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
  203. if err != nil {
  204. return err
  205. }
  206. params := &stripe.SubscriptionParams{
  207. CancelAtPeriodEnd: stripe.Bool(false),
  208. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
  209. Items: []*stripe.SubscriptionItemsParams{
  210. {
  211. ID: stripe.String(sub.Items.Data[0].ID),
  212. Price: stripe.String(tier.StripePriceID),
  213. },
  214. },
  215. }
  216. _, err = s.stripe.UpdateSubscription(sub.ID, params)
  217. if err != nil {
  218. return err
  219. }
  220. return s.writeJSON(w, newSuccessResponse())
  221. }
  222. // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
  223. // and cancelling the Stripe subscription entirely
  224. func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
  225. log.Info("%s Deleting billing subscription %s", logHTTPPrefix(v, r), v.user.Billing.StripeSubscriptionID)
  226. if v.user.Billing.StripeSubscriptionID != "" {
  227. params := &stripe.SubscriptionParams{
  228. CancelAtPeriodEnd: stripe.Bool(true),
  229. }
  230. _, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
  231. if err != nil {
  232. return err
  233. }
  234. }
  235. return s.writeJSON(w, newSuccessResponse())
  236. }
  237. // handleAccountBillingPortalSessionCreate creates a session to the customer billing portal, and returns the
  238. // redirect URL. The billing portal allows customers to change their payment methods, and cancel the subscription.
  239. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  240. if v.user.Billing.StripeCustomerID == "" {
  241. return errHTTPBadRequestNotAPaidUser
  242. }
  243. log.Info("%s Creating billing portal session", logHTTPPrefix(v, r))
  244. params := &stripe.BillingPortalSessionParams{
  245. Customer: stripe.String(v.user.Billing.StripeCustomerID),
  246. ReturnURL: stripe.String(s.config.BaseURL),
  247. }
  248. ps, err := s.stripe.NewPortalSession(params)
  249. if err != nil {
  250. return err
  251. }
  252. response := &apiAccountBillingPortalRedirectResponse{
  253. RedirectURL: ps.URL,
  254. }
  255. return s.writeJSON(w, response)
  256. }
  257. // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
  258. // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
  259. // visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
  260. func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  261. stripeSignature := r.Header.Get("Stripe-Signature")
  262. if stripeSignature == "" {
  263. return errHTTPBadRequestBillingRequestInvalid
  264. }
  265. body, err := util.Peek(r.Body, jsonBodyBytesLimit)
  266. if err != nil {
  267. return err
  268. } else if body.LimitReached {
  269. return errHTTPEntityTooLargeJSONBody
  270. }
  271. event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
  272. if err != nil {
  273. return err
  274. } else if event.Data == nil || event.Data.Raw == nil {
  275. return errHTTPBadRequestBillingRequestInvalid
  276. }
  277. switch event.Type {
  278. case "customer.subscription.updated":
  279. return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
  280. case "customer.subscription.deleted":
  281. return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
  282. default:
  283. log.Warn("STRIPE Unhandled webhook event %s received", event.Type)
  284. return nil
  285. }
  286. }
  287. func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
  288. ev, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
  289. if err != nil {
  290. return err
  291. } else if ev.ID == "" || ev.Customer == "" || ev.Status == "" || ev.CurrentPeriodEnd == 0 || ev.Items == nil || len(ev.Items.Data) != 1 || ev.Items.Data[0].Price == nil || ev.Items.Data[0].Price.ID == "" {
  292. return errHTTPBadRequestBillingRequestInvalid
  293. }
  294. subscriptionID, priceID := ev.ID, ev.Items.Data[0].Price.ID
  295. log.Info("%s Updating subscription to status %s, with price %s", logStripePrefix(ev.Customer, ev.ID), ev.Status, priceID)
  296. userFn := func() (*user.User, error) {
  297. return s.userManager.UserByStripeCustomer(ev.Customer)
  298. }
  299. u, err := util.Retry[user.User](userFn, retryUserDelays...)
  300. if err != nil {
  301. return err
  302. }
  303. tier, err := s.userManager.TierByStripePrice(priceID)
  304. if err != nil {
  305. return err
  306. }
  307. if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, tier, ev.Customer, subscriptionID, ev.Status, ev.CurrentPeriodEnd, ev.CancelAt); err != nil {
  308. return err
  309. }
  310. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  311. return nil
  312. }
  313. func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
  314. ev, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
  315. if err != nil {
  316. return err
  317. } else if ev.Customer == "" {
  318. return errHTTPBadRequestBillingRequestInvalid
  319. }
  320. log.Info("%s Subscription deleted, downgrading to unpaid tier", logStripePrefix(ev.Customer, ev.ID))
  321. u, err := s.userManager.UserByStripeCustomer(ev.Customer)
  322. if err != nil {
  323. return err
  324. }
  325. if err := s.updateSubscriptionAndTier(logStripePrefix(ev.Customer, ev.ID), u, nil, ev.Customer, "", "", 0, 0); err != nil {
  326. return err
  327. }
  328. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  329. return nil
  330. }
  331. // maybeRemoveExcessReservations deletes topic reservations for the given user (if too many for tier),
  332. // and marks associated messages for the topics as deleted. This also eventually deletes attachments.
  333. // The process relies on the manager to perform the actual deletions (see runManager).
  334. func (s *Server) maybeRemoveExcessReservations(logPrefix string, u *user.User, reservationsLimit int64) error {
  335. reservations, err := s.userManager.Reservations(u.Name)
  336. if err != nil {
  337. return err
  338. } else if int64(len(reservations)) <= reservationsLimit {
  339. return nil
  340. }
  341. topics := make([]string, 0)
  342. for i := int64(len(reservations)) - 1; i >= reservationsLimit; i-- {
  343. topics = append(topics, reservations[i].Topic)
  344. }
  345. log.Info("%s Removing excess reservations for topics %s", logPrefix, strings.Join(topics, ", "))
  346. if err := s.userManager.RemoveReservations(u.Name, topics...); err != nil {
  347. return err
  348. }
  349. if err := s.messageCache.ExpireMessages(topics...); err != nil {
  350. return err
  351. }
  352. return nil
  353. }
  354. func (s *Server) updateSubscriptionAndTier(logPrefix string, u *user.User, tier *user.Tier, customerID, subscriptionID, status string, paidUntil, cancelAt int64) error {
  355. reservationsLimit := visitorDefaultReservationsLimit
  356. if tier != nil {
  357. reservationsLimit = tier.ReservationsLimit
  358. }
  359. if err := s.maybeRemoveExcessReservations(logPrefix, u, reservationsLimit); err != nil {
  360. return err
  361. }
  362. if tier == nil {
  363. if err := s.userManager.ResetTier(u.Name); err != nil {
  364. return err
  365. }
  366. } else {
  367. if err := s.userManager.ChangeTier(u.Name, tier.Code); err != nil {
  368. return err
  369. }
  370. }
  371. // Update billing fields
  372. billing := &user.Billing{
  373. StripeCustomerID: customerID,
  374. StripeSubscriptionID: subscriptionID,
  375. StripeSubscriptionStatus: stripe.SubscriptionStatus(status),
  376. StripeSubscriptionPaidUntil: time.Unix(paidUntil, 0),
  377. StripeSubscriptionCancelAt: time.Unix(cancelAt, 0),
  378. }
  379. if err := s.userManager.ChangeBilling(u.Name, billing); err != nil {
  380. return err
  381. }
  382. return nil
  383. }
  384. // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
  385. // in memory, and ultimately for the web app to display the price table.
  386. func (s *Server) fetchStripePrices() (map[string]string, error) {
  387. log.Debug("Caching prices from Stripe API")
  388. priceMap := make(map[string]string)
  389. prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
  390. if err != nil {
  391. log.Warn("Fetching Stripe prices failed: %s", err.Error())
  392. return nil, err
  393. }
  394. for _, p := range prices {
  395. if p.UnitAmount%100 == 0 {
  396. priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
  397. } else {
  398. priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
  399. }
  400. log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
  401. }
  402. return priceMap, nil
  403. }
  404. // stripeAPI is a small interface to facilitate mocking of the Stripe API
  405. type stripeAPI interface {
  406. NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error)
  407. NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error)
  408. ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error)
  409. GetCustomer(id string) (*stripe.Customer, error)
  410. GetSession(id string) (*stripe.CheckoutSession, error)
  411. GetSubscription(id string) (*stripe.Subscription, error)
  412. UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
  413. CancelSubscription(id string) (*stripe.Subscription, error)
  414. ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
  415. }
  416. // realStripeAPI is a thin shim around the Stripe functions to facilitate mocking
  417. type realStripeAPI struct{}
  418. var _ stripeAPI = (*realStripeAPI)(nil)
  419. func newStripeAPI() stripeAPI {
  420. return &realStripeAPI{}
  421. }
  422. func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  423. return session.New(params)
  424. }
  425. func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  426. return portalsession.New(params)
  427. }
  428. func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  429. prices := make([]*stripe.Price, 0)
  430. iter := price.List(params)
  431. for iter.Next() {
  432. prices = append(prices, iter.Price())
  433. }
  434. if iter.Err() != nil {
  435. return nil, iter.Err()
  436. }
  437. return prices, nil
  438. }
  439. func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  440. return customer.Get(id, nil)
  441. }
  442. func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  443. return session.Get(id, nil)
  444. }
  445. func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  446. return subscription.Get(id, nil)
  447. }
  448. func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  449. return subscription.Update(id, params)
  450. }
  451. func (s *realStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  452. return subscription.Cancel(id, nil)
  453. }
  454. func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  455. return webhook.ConstructEvent(payload, header, secret)
  456. }