server_payments.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332
  1. package server
  import (
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"net/netip"
	"sync"
	"time"

	"github.com/stripe/stripe-go/v74"
	portalsession "github.com/stripe/stripe-go/v74/billingportal/session"
	"github.com/stripe/stripe-go/v74/checkout/session"
	"github.com/stripe/stripe-go/v74/customer"
	"github.com/stripe/stripe-go/v74/price"
	"github.com/stripe/stripe-go/v74/subscription"
	"github.com/stripe/stripe-go/v74/webhook"
	"github.com/tidwall/gjson"
	"heckel.io/ntfy/log"
	"heckel.io/ntfy/user"
	"heckel.io/ntfy/util"
)
  // stripeBodyBytesLimit is the maximum number of bytes of a Stripe webhook request body that
// handleAccountBillingWebhook will read; larger bodies are rejected as too large (16 KiB).
const stripeBodyBytesLimit = 16 * 1024
  24. func (s *Server) handleAccountBillingTiersGet(w http.ResponseWriter, r *http.Request, v *visitor) error {
  25. tiers, err := v.userManager.Tiers()
  26. if err != nil {
  27. return err
  28. }
  29. response := make([]*apiAccountBillingTier, 0)
  30. for _, tier := range tiers {
  31. if tier.StripePriceID == "" {
  32. continue
  33. }
  34. priceStr, ok := s.priceCache[tier.StripePriceID]
  35. if !ok {
  36. p, err := price.Get(tier.StripePriceID, nil)
  37. if err != nil {
  38. return err
  39. }
  40. if p.UnitAmount%100 == 0 {
  41. priceStr = fmt.Sprintf("$%d", p.UnitAmount/100)
  42. } else {
  43. priceStr = fmt.Sprintf("$%.2f", float64(p.UnitAmount)/100)
  44. }
  45. s.priceCache[tier.StripePriceID] = priceStr // FIXME race, make this sync.Map or something
  46. }
  47. response = append(response, &apiAccountBillingTier{
  48. Code: tier.Code,
  49. Name: tier.Name,
  50. Price: priceStr,
  51. Features: tier.Features,
  52. })
  53. }
  54. w.Header().Set("Content-Type", "application/json")
  55. w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
  56. if err := json.NewEncoder(w).Encode(response); err != nil {
  57. return err
  58. }
  59. return nil
  60. }
  61. // handleAccountBillingSubscriptionCreate creates a Stripe checkout flow to create a user subscription. The tier
  62. // will be updated by a subsequent webhook from Stripe, once the subscription becomes active.
  63. func (s *Server) handleAccountBillingSubscriptionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  64. if v.user.Billing.StripeSubscriptionID != "" {
  65. return errors.New("subscription already exists") //FIXME
  66. }
  67. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  68. if err != nil {
  69. return err
  70. }
  71. tier, err := s.userManager.Tier(req.Tier)
  72. if err != nil {
  73. return err
  74. }
  75. if tier.StripePriceID == "" {
  76. return errors.New("invalid tier") //FIXME
  77. }
  78. log.Info("Stripe: No existing subscription, creating checkout flow")
  79. var stripeCustomerID *string
  80. if v.user.Billing.StripeCustomerID != "" {
  81. stripeCustomerID = &v.user.Billing.StripeCustomerID
  82. stripeCustomer, err := customer.Get(v.user.Billing.StripeCustomerID, nil)
  83. if err != nil {
  84. return err
  85. } else if stripeCustomer.Subscriptions != nil && len(stripeCustomer.Subscriptions.Data) > 0 {
  86. return errors.New("customer cannot have more than one subscription") //FIXME
  87. }
  88. }
  89. successURL := s.config.BaseURL + apiAccountBillingSubscriptionCheckoutSuccessTemplate
  90. params := &stripe.CheckoutSessionParams{
  91. Customer: stripeCustomerID, // A user may have previously deleted their subscription
  92. ClientReferenceID: &v.user.Name,
  93. SuccessURL: &successURL,
  94. Mode: stripe.String(string(stripe.CheckoutSessionModeSubscription)),
  95. LineItems: []*stripe.CheckoutSessionLineItemParams{
  96. {
  97. Price: stripe.String(tier.StripePriceID),
  98. Quantity: stripe.Int64(1),
  99. },
  100. },
  101. /*AutomaticTax: &stripe.CheckoutSessionAutomaticTaxParams{
  102. Enabled: stripe.Bool(true),
  103. },*/
  104. }
  105. sess, err := session.New(params)
  106. if err != nil {
  107. return err
  108. }
  109. response := &apiAccountBillingSubscriptionCreateResponse{
  110. RedirectURL: sess.URL,
  111. }
  112. w.Header().Set("Content-Type", "application/json")
  113. w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
  114. if err := json.NewEncoder(w).Encode(response); err != nil {
  115. return err
  116. }
  117. return nil
  118. }
  119. func (s *Server) handleAccountBillingSubscriptionCreateSuccess(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  120. // We don't have a v.user in this endpoint, only a userManager!
  121. matches := apiAccountBillingSubscriptionCheckoutSuccessRegex.FindStringSubmatch(r.URL.Path)
  122. if len(matches) != 2 {
  123. return errHTTPInternalErrorInvalidPath
  124. }
  125. sessionID := matches[1]
  126. sess, err := session.Get(sessionID, nil) // FIXME how do I rate limit this?
  127. if err != nil {
  128. log.Warn("Stripe: %s", err)
  129. return errHTTPBadRequestInvalidStripeRequest
  130. } else if sess.Customer == nil || sess.Subscription == nil || sess.ClientReferenceID == "" {
  131. return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "customer or subscription not found")
  132. }
  133. sub, err := subscription.Get(sess.Subscription.ID, nil)
  134. if err != nil {
  135. return err
  136. } else if sub.Items == nil || len(sub.Items.Data) != 1 || sub.Items.Data[0].Price == nil {
  137. return wrapErrHTTP(errHTTPBadRequestInvalidStripeRequest, "more than one line item in existing subscription")
  138. }
  139. tier, err := s.userManager.TierByStripePrice(sub.Items.Data[0].Price.ID)
  140. if err != nil {
  141. return err
  142. }
  143. u, err := s.userManager.User(sess.ClientReferenceID)
  144. if err != nil {
  145. return err
  146. }
  147. if err := s.updateSubscriptionAndTier(u, sess.Customer.ID, sub.ID, string(sub.Status), sub.CurrentPeriodEnd, sub.CancelAt, tier.Code); err != nil {
  148. return err
  149. }
  150. http.Redirect(w, r, s.config.BaseURL+accountPath, http.StatusSeeOther)
  151. return nil
  152. }
  153. // handleAccountBillingSubscriptionUpdate updates an existing Stripe subscription to a new price, and updates
  154. // a user's tier accordingly. This endpoint only works if there is an existing subscription.
  155. func (s *Server) handleAccountBillingSubscriptionUpdate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  156. if v.user.Billing.StripeSubscriptionID == "" {
  157. return errors.New("no existing subscription for user")
  158. }
  159. req, err := readJSONWithLimit[apiAccountBillingSubscriptionChangeRequest](r.Body, jsonBodyBytesLimit)
  160. if err != nil {
  161. return err
  162. }
  163. tier, err := s.userManager.Tier(req.Tier)
  164. if err != nil {
  165. return err
  166. }
  167. log.Info("Stripe: Changing tier and subscription to %s", tier.Code)
  168. sub, err := subscription.Get(v.user.Billing.StripeSubscriptionID, nil)
  169. if err != nil {
  170. return err
  171. }
  172. params := &stripe.SubscriptionParams{
  173. CancelAtPeriodEnd: stripe.Bool(false),
  174. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorCreateProrations)),
  175. Items: []*stripe.SubscriptionItemsParams{
  176. {
  177. ID: stripe.String(sub.Items.Data[0].ID),
  178. Price: stripe.String(tier.StripePriceID),
  179. },
  180. },
  181. }
  182. _, err = subscription.Update(sub.ID, params)
  183. if err != nil {
  184. return err
  185. }
  186. w.Header().Set("Content-Type", "application/json")
  187. w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
  188. if err := json.NewEncoder(w).Encode(newSuccessResponse()); err != nil {
  189. return err
  190. }
  191. return nil
  192. }
  193. // handleAccountBillingSubscriptionDelete facilitates downgrading a paid user to a tier-less user,
  194. // and cancelling the Stripe subscription entirely
  195. func (s *Server) handleAccountBillingSubscriptionDelete(w http.ResponseWriter, r *http.Request, v *visitor) error {
  196. if v.user.Billing.StripeCustomerID == "" {
  197. return errHTTPBadRequestNotAPaidUser
  198. }
  199. if v.user.Billing.StripeSubscriptionID != "" {
  200. params := &stripe.SubscriptionParams{
  201. CancelAtPeriodEnd: stripe.Bool(true),
  202. }
  203. _, err := subscription.Update(v.user.Billing.StripeSubscriptionID, params)
  204. if err != nil {
  205. return err
  206. }
  207. }
  208. return nil
  209. }
  210. func (s *Server) handleAccountBillingPortalSessionCreate(w http.ResponseWriter, r *http.Request, v *visitor) error {
  211. if v.user.Billing.StripeCustomerID == "" {
  212. return errHTTPBadRequestNotAPaidUser
  213. }
  214. params := &stripe.BillingPortalSessionParams{
  215. Customer: stripe.String(v.user.Billing.StripeCustomerID),
  216. ReturnURL: stripe.String(s.config.BaseURL),
  217. }
  218. ps, err := portalsession.New(params)
  219. if err != nil {
  220. return err
  221. }
  222. response := &apiAccountBillingPortalRedirectResponse{
  223. RedirectURL: ps.URL,
  224. }
  225. w.Header().Set("Content-Type", "application/json")
  226. w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
  227. if err := json.NewEncoder(w).Encode(response); err != nil {
  228. return err
  229. }
  230. return nil
  231. }
  232. func (s *Server) handleAccountBillingWebhook(w http.ResponseWriter, r *http.Request, _ *visitor) error {
  233. // Note that the visitor (v) in this endpoint is the Stripe API, so we don't have v.user available
  234. stripeSignature := r.Header.Get("Stripe-Signature")
  235. if stripeSignature == "" {
  236. return errHTTPBadRequestInvalidStripeRequest
  237. }
  238. body, err := util.Peek(r.Body, stripeBodyBytesLimit)
  239. if err != nil {
  240. return err
  241. } else if body.LimitReached {
  242. return errHTTPEntityTooLargeJSONBody
  243. }
  244. event, err := webhook.ConstructEvent(body.PeekedBytes, stripeSignature, s.config.StripeWebhookKey)
  245. if err != nil {
  246. return errHTTPBadRequestInvalidStripeRequest
  247. } else if event.Data == nil || event.Data.Raw == nil {
  248. return errHTTPBadRequestInvalidStripeRequest
  249. }
  250. log.Info("Stripe: webhook event %s received", event.Type)
  251. switch event.Type {
  252. case "customer.subscription.updated":
  253. return s.handleAccountBillingWebhookSubscriptionUpdated(event.Data.Raw)
  254. case "customer.subscription.deleted":
  255. return s.handleAccountBillingWebhookSubscriptionDeleted(event.Data.Raw)
  256. default:
  257. return nil
  258. }
  259. }
  260. func (s *Server) handleAccountBillingWebhookSubscriptionUpdated(event json.RawMessage) error {
  261. subscriptionID := gjson.GetBytes(event, "id")
  262. customerID := gjson.GetBytes(event, "customer")
  263. status := gjson.GetBytes(event, "status")
  264. currentPeriodEnd := gjson.GetBytes(event, "current_period_end")
  265. cancelAt := gjson.GetBytes(event, "cancel_at")
  266. priceID := gjson.GetBytes(event, "items.data.0.price.id")
  267. if !subscriptionID.Exists() || !status.Exists() || !currentPeriodEnd.Exists() || !cancelAt.Exists() || !priceID.Exists() {
  268. return errHTTPBadRequestInvalidStripeRequest
  269. }
  270. log.Info("Stripe: customer %s: Updating subscription to status %s, with price %s", customerID.String(), status, priceID)
  271. u, err := s.userManager.UserByStripeCustomer(customerID.String())
  272. if err != nil {
  273. return err
  274. }
  275. tier, err := s.userManager.TierByStripePrice(priceID.String())
  276. if err != nil {
  277. return err
  278. }
  279. if err := s.updateSubscriptionAndTier(u, customerID.String(), subscriptionID.String(), status.String(), currentPeriodEnd.Int(), cancelAt.Int(), tier.Code); err != nil {
  280. return err
  281. }
  282. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  283. return nil
  284. }
  285. func (s *Server) handleAccountBillingWebhookSubscriptionDeleted(event json.RawMessage) error {
  286. customerID := gjson.GetBytes(event, "customer")
  287. if !customerID.Exists() {
  288. return errHTTPBadRequestInvalidStripeRequest
  289. }
  290. log.Info("Stripe: customer %s: subscription deleted, downgrading to unpaid tier", customerID.String())
  291. u, err := s.userManager.UserByStripeCustomer(customerID.String())
  292. if err != nil {
  293. return err
  294. }
  295. if err := s.updateSubscriptionAndTier(u, customerID.String(), "", "", 0, 0, ""); err != nil {
  296. return err
  297. }
  298. s.publishSyncEventAsync(s.visitorFromUser(u, netip.IPv4Unspecified()))
  299. return nil
  300. }
  301. func (s *Server) updateSubscriptionAndTier(u *user.User, customerID, subscriptionID, status string, paidUntil, cancelAt int64, tier string) error {
  302. u.Billing.StripeCustomerID = customerID
  303. u.Billing.StripeSubscriptionID = subscriptionID
  304. u.Billing.StripeSubscriptionStatus = stripe.SubscriptionStatus(status)
  305. u.Billing.StripeSubscriptionPaidUntil = time.Unix(paidUntil, 0)
  306. u.Billing.StripeSubscriptionCancelAt = time.Unix(cancelAt, 0)
  307. if tier == "" {
  308. if err := s.userManager.ResetTier(u.Name); err != nil {
  309. return err
  310. }
  311. } else {
  312. if err := s.userManager.ChangeTier(u.Name, tier); err != nil {
  313. return err
  314. }
  315. }
  316. if err := s.userManager.ChangeBilling(u); err != nil {
  317. return err
  318. }
  319. return nil
  320. }