server_payments.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. package server
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "errors"
  6. "fmt"
  7. "github.com/stripe/stripe-go/v74"
  8. portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
  9. "github.com/stripe/stripe-go/v74/checkout/session"
  10. "github.com/stripe/stripe-go/v74/customer"
  11. "github.com/stripe/stripe-go/v74/price"
  12. "github.com/stripe/stripe-go/v74/subscription"
  13. "github.com/stripe/stripe-go/v74/webhook"
  14. "heckel.io/ntfy/log"
  15. "heckel.io/ntfy/user"
  16. "heckel.io/ntfy/util"
  17. "io"
  18. "net/http"
  19. "net/netip"
  20. "time"
  21. )
  // Sentinel errors for the billing handlers in this file.
var (
	// errNotAPaidTier: the requested tier has no Stripe price ID, so there is nothing to bill.
	errNotAPaidTier = errors.New("tier does not have billing price identifier")
	// errMultipleBillingSubscriptions: a checkout was requested for a Stripe customer that already has
	// at least one subscription.
	errMultipleBillingSubscriptions = errors.New("cannot have multiple billing subscriptions")
	// errNoBillingSubscription: a subscription change was requested for a user without a subscription ID.
	errNoBillingSubscription = errors.New("user does not have an active billing subscription")
)
  27. // handleBillingTiersGet returns all available paid tiers, and the free tier. This is to populate the upgrade dialog
  28. // in the UI. Note that this endpoint does NOT have a user context (no v.user!).
  29. func (s *Server) handleBillingTiersGet(w http.ResponseWriter, _ *http.Request, _ *visitor) error {
  30. tiers, err := s.userManager.Tiers()
  31. if err != nil {
  32. return err
  33. }
  34. freeTier := defaultVisitorLimits(s.config)
  35. response := []*apiAccountBillingTier{
  36. {
  37. // Free tier: no code, name or price
  38. Limits: &apiAccountLimits{
  39. Messages: freeTier.MessagesLimit,
  40. MessagesExpiryDuration: int64(freeTier.MessagesExpiryDuration.Seconds()),
  41. Emails: freeTier.EmailsLimit,
  42. Reservations: freeTier.ReservationsLimit,
  43. AttachmentTotalSize: freeTier.AttachmentTotalSizeLimit,
  44. AttachmentFileSize: freeTier.AttachmentFileSizeLimit,
  45. AttachmentExpiryDuration: int64(freeTier.AttachmentExpiryDuration.Seconds()),
  46. },
  47. },
  48. }
  49. prices, err := s.priceCache.Value()
  50. if err != nil {
  51. return err
  52. }
  53. for _, tier := range tiers {
  54. priceStr, ok := prices[tier.StripePriceID]
  55. if tier.StripePriceID == "" || !ok {
  56. continue
  57. }
  58. response = append(response, &apiAccountBillingTier{
  59. Code: tier.Code,
  60. Name: tier.Name,
  61. Price: priceStr,
  62. Limits: &apiAccountLimits{
  63. Messages: tier.MessagesLimit,
  64. MessagesExpiryDuration: int64(tier.MessagesExpiryDuration.Seconds()),
  65. Emails: tier.EmailsLimit,
  66. Reservations: tier.ReservationsLimit,
  67. AttachmentTotalSize: tier.AttachmentTotalSizeLimit,
  68. AttachmentFileSize: tier.AttachmentFileSizeLimit,
  69. AttachmentExpiryDuration: int64(tier.AttachmentExpiryDuration.Seconds()),
  70. },
  71. })
  72. }
  73. return s.writeJSON(w, response)
  74. }
  75. // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
  76. // will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
  77. func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  78. if v.user.Billing.StripeSubscriptionID != "" {
  79. return errHTTPBadRequestBillingSubscriptionExists
  80. }
  81. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  82. if err != nil {
  83. return err
  84. }
  85. tier, err := s.userManager.Tier(req.Tier)
  86. if err != nil {
  87. return err
  88. } else if tier.StripePriceID == "" {
  89. return errNotAPaidTier
  90. }
  91. log.Info("Stripe: No existing subscription, creating checkout flow")
  92. var stripeCustomerID *string
  93. if v.user.Billing.StripeCustomerID != "" {
  94. stripeCustomerID = &v.user.Billing.StripeCustomerID
  95. stripeCustomer, err := s.stripe.GetCustomer(v.user.Billing.StripeCustomerID)
  96. if err != nil {
  97. return err
  98. } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
  99. return errMultipleBillingSubscriptions
  100. }
  101. }
  102. successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
  103. params := &stripe.CheckoutSessionParams{
  104. Customer: stripeCustomerID, // A user may have previously deleted their subscription
  105. ClientReferenceID: &v.user.Name,
  106. SuccessURL: &successURL,
  107. Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
  108. AllowPromotionCodes: stripe.Bool(true),
  109. LineItems: []*stripe.CheckoutSessionLineItemParams{
  110. {
  111. Price: stripe.String(tier.StripePriceID),
  112. Quantity: stripe.Int64(1),
  113. },
  114. },
  115. /*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
  116. Enabled: stripe.Bool(true),
  117. },*/
  118. }
  119. sess, err := s.stripe.NewCheckoutSession(params)
  120. if err != nil {
  121. return err
  122. }
  123. response := &apiAccountBillingSubscriptionCreateResponse{
  124. RedirectURL: sess.URL,
  125. }
  126. return s.writeJSON(w, response)
  127. }
  128. func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  129. // We don't have a v.user in this endpoint, only a userManager!
  130. matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
  131. if len(matches) != 2 {
  132. return errHTTPInternalErrorInvalidPath
  133. }
  134. sessionID := matches[1]
  135. sess, err := s.stripe.GetSession(sessionID) // FIXME How do we rate limit this?
  136. if err != nil {
  137. log.Warn("Stripe: %s", err)
  138. return errHTTPBadRequestBillingRequestInvalid
  139. } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
  140. return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "customer or subscription not found")
  141. }
  142. sub, err := s.stripe.GetSubscription(sess.Subscription.ID)
  143. if err != nil {
  144. return err
  145. } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
  146. return wrapErrHTTP(errHTTPBadRequestBillingRequestInvalid, "more than one line item in existing subscription")
  147. }
  148. tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
  149. if err != nil {
  150. return err
  151. }
  152. u, err := s.userManager.User(sess.ClientReferenceID)
  153. if err != nil {
  154. return err
  155. }
  156. if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
  157. return err
  158. }
  159. http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
  160. return nil
  161. }
  162. // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
  163. // a user's tier accordingly. This endpoint only works if there is an existing subscription.
  164. func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  165. if v.user.Billing.StripeSubscriptionID == "" {
  166. return errNoBillingSubscription
  167. }
  168. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  169. if err != nil {
  170. return err
  171. }
  172. tier, err := s.userManager.Tier(req.Tier)
  173. if err != nil {
  174. return err
  175. }
  176. log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
  177. sub, err := s.stripe.GetSubscription(v.user.Billing.StripeSubscriptionID)
  178. if err != nil {
  179. return err
  180. }
  181. params := &stripe.SubscriptionParams{
  182. CancelAtPeriodEnd: stripe.Bool(false),
  183. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
  184. Items: []*stripe.SubscriptionItemsParams{
  185. {
  186. ID: stripe.String(sub.Items.Data[0].ID),
  187. Price: stripe.String(tier.StripePriceID),
  188. },
  189. },
  190. }
  191. _, err = s.stripe.UpdateSubscription(sub.ID, params)
  192. if err != nil {
  193. return err
  194. }
  195. return s.writeJSON(w, newSuccessResponse())
  196. }
  197. // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
  198. // and cancelling the Stripe subscription entirely
  199. func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
  200. if v.user.Billing.StripeSubscriptionID != "" {
  201. params := &stripe.SubscriptionParams{
  202. CancelAtPeriodEnd: stripe.Bool(true),
  203. }
  204. _, err := s.stripe.UpdateSubscription(v.user.Billing.StripeSubscriptionID, params)
  205. if err != nil {
  206. return err
  207. }
  208. }
  209. return s.writeJSON(w, newSuccessResponse())
  210. }
  211. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  212. if v.user.Billing.StripeCustomerID == "" {
  213. return errHTTPBadRequestNotAPaidUser
  214. }
  215. params := &stripe.BillingPortalSessionParams{
  216. Customer: stripe.String(v.user.Billing.StripeCustomerID),
  217. ReturnURL: stripe.String(s.config.BaseURL),
  218. }
  219. ps, err := s.stripe.NewPortalSession(params)
  220. if err != nil {
  221. return err
  222. }
  223. response := &apiAccountBillingPortalRedirectResponse{
  224. RedirectURL: ps.URL,
  225. }
  226. return s.writeJSON(w, response)
  227. }
  228. // handleAccountBillingWebhook handles incoming Stripe webhooks. It mainly keeps the local user database in sync
  229. // with the Stripe view of the world. This endpoint is authorized via the Stripe webhook secret. Note that the
  230. // visitor (v) in this endpoint is the Stripe API, so we don't have v.user available.
  231. func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  232. stripeSignature := r.Header.Get("Stripe-Signature")
  233. if stripeSignature == "" {
  234. return errHTTPBadRequestBillingRequestInvalid
  235. }
  236. body, err := util.Peek(r.Body, jsonBodyBytesLimit)
  237. if err != nil {
  238. return err
  239. } else if body.LimitReached {
  240. return errHTTPEntityTooLargeJSONBody
  241. }
  242. event, err := s.stripe.ConstructWebhookEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
  243. if err != nil {
  244. return errHTTPBadRequestBillingRequestInvalid
  245. } else if event.Data == nil || event.Data.Raw == nil {
  246. return errHTTPBadRequestBillingRequestInvalid
  247. }
  248. log.Info("Stripe: webhook event %s received", event.Type)
  249. switch event.Type {
  250. case "customer.subscription.updated":
  251. return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
  252. case "customer.subscription.deleted":
  253. return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
  254. default:
  255. return nil
  256. }
  257. }
  258. func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
  259. r, err := util.UnmarshalJSON[apiStripeSubscriptionUpdatedEvent](io.NopCloser(bytes.NewReader(event)))
  260. if err != nil {
  261. return err
  262. } else if r.ID == "" || r.Customer == "" || r.Status == "" || r.CurrentPeriodEnd == 0 || r.Items == nil || len(r.Items.Data) != 1 || r.Items.Data[0].Price == nil || r.Items.Data[0].Price.ID == "" {
  263. return errHTTPBadRequestBillingRequestInvalid
  264. }
  265. subscriptionID, priceID := r.ID, r.Items.Data[0].Price.ID
  266. log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", r.Customer, r.Status, priceID)
  267. u, err := s.userManager.UserByStripeCustomer(r.Customer)
  268. if err != nil {
  269. return err
  270. }
  271. tier, err := s.userManager.TierByStripePrice(priceID)
  272. if err != nil {
  273. return err
  274. }
  275. if err := s.updateSubscriptionAndTier(u, r.Customer, subscriptionID, r.Status, r.CurrentPeriodEnd, r.CancelAt, tier.Code); err != nil {
  276. return err
  277. }
  278. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  279. return nil
  280. }
  281. func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
  282. r, err := util.UnmarshalJSON[apiStripeSubscriptionDeletedEvent](io.NopCloser(bytes.NewReader(event)))
  283. if err != nil {
  284. return err
  285. } else if r.Customer == "" {
  286. return errHTTPBadRequestBillingRequestInvalid
  287. }
  288. log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", r.Customer)
  289. u, err := s.userManager.UserByStripeCustomer(r.Customer)
  290. if err != nil {
  291. return err
  292. }
  293. if err := s.updateSubscriptionAndTier(u, r.Customer, "", "", 0, 0, ""); err != nil {
  294. return err
  295. }
  296. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  297. return nil
  298. }
  299. func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
  300. u.Billing.StripeCustomerID = customerID
  301. u.Billing.StripeSubscriptionID = subscriptionID
  302. u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
  303. u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
  304. u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
  305. if tier == "" {
  306. if err := s.userManager.ResetTier(u.Name); err != nil {
  307. return err
  308. }
  309. } else {
  310. if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
  311. return err
  312. }
  313. }
  314. if err := s.userManager.ChangeBilling(u); err != nil {
  315. return err
  316. }
  317. return nil
  318. }
  319. // fetchStripePrices contacts the Stripe API to retrieve all prices. This is used by the server to cache the prices
  320. // in memory, and ultimately for the web app to display the price table.
  321. func (s *Server) fetchStripePrices() (map[string]string, error) {
  322. log.Debug("Caching prices from Stripe API")
  323. priceMap := make(map[string]string)
  324. prices, err := s.stripe.ListPrices(&stripe.PriceListParams{Active: stripe.Bool(true)})
  325. if err != nil {
  326. log.Warn("Fetching Stripe prices failed: %s", err.Error())
  327. return nil, err
  328. }
  329. for _, p := range prices {
  330. if p.UnitAmount%100 == 0 {
  331. priceMap[p.ID] = fmt.Sprintf("$%d", p.UnitAmount/100)
  332. } else {
  333. priceMap[p.ID] = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
  334. }
  335. log.Trace("- Caching price %s = %v", p.ID, priceMap[p.ID])
  336. }
  337. return priceMap, nil
  338. }
  339. // stripeAPI is a small interface to facilitate mocking of the Stripe API
  340. type stripeAPI interface {
  341. NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error)
  342. NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error)
  343. ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error)
  344. GetCustomer(id string) (*stripe.Customer, error)
  345. GetSession(id string) (*stripe.CheckoutSession, error)
  346. GetSubscription(id string) (*stripe.Subscription, error)
  347. UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error)
  348. ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error)
  349. }
  350. // realStripeAPI is a thin shim around the Stripe functions to facilitate mocking
  351. type realStripeAPI struct{}
  352. var _ stripeAPI = (*realStripeAPI)(nil)
  353. func newStripeAPI() stripeAPI {
  354. return &realStripeAPI{}
  355. }
  356. func (s *realStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  357. return session.New(params)
  358. }
  359. func (s *realStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  360. return portalsession.New(params)
  361. }
  362. func (s *realStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  363. prices := make([]*stripe.Price, 0)
  364. iter := price.List(params)
  365. for iter.Next() {
  366. prices = append(prices, iter.Price())
  367. }
  368. if iter.Err() != nil {
  369. return nil, iter.Err()
  370. }
  371. return prices, nil
  372. }
  373. func (s *realStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  374. return customer.Get(id, nil)
  375. }
  376. func (s *realStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  377. return session.Get(id, nil)
  378. }
  379. func (s *realStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  380. return subscription.Get(id, nil)
  381. }
  382. func (s *realStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  383. return subscription.Update(id, params)
  384. }
  385. func (s *realStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  386. return webhook.ConstructEvent(payload, header, secret)
  387. }