server_payments_test.go 29 KB


  1. package server
  2. import (
  3. "encoding/json"
  4. "github.com/stretchr/testify/mock"
  5. "github.com/stretchr/testify/require"
  6. "github.com/stripe/stripe-go/v74"
  7. "golang.org/x/time/rate"
  8. "heckel.io/ntfy/user"
  9. "heckel.io/ntfy/util"
  10. "io"
  11. "net/netip"
  12. "path/filepath"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "time"
  17. )
  18. func TestPayments_Tiers(t *testing.T) {
  19. stripeMock := &testStripeAPI{}
  20. defer stripeMock.AssertExpectations(t)
  21. c := newTestConfigWithAuthFile(t)
  22. c.StripeSecretKey = "secret key"
  23. c.StripeWebhookKey = "webhook key"
  24. c.VisitorRequestLimitReplenish = 12 * time.Hour
  25. c.CacheDuration = 13 * time.Hour
  26. c.AttachmentFileSizeLimit = 111
  27. c.VisitorAttachmentTotalSizeLimit = 222
  28. c.AttachmentExpiryDuration = 123 * time.Second
  29. s := newTestServer(t, c)
  30. s.stripe = stripeMock
  31. // Define how the mock should react
  32. stripeMock.
  33. On("ListPrices", mock.Anything).
  34. Return([]*stripe.Price{
  35. {ID: "price_123", UnitAmount: 500},
  36. {ID: "price_124", UnitAmount: 5000},
  37. {ID: "price_456", UnitAmount: 1000},
  38. {ID: "price_457", UnitAmount: 10000},
  39. {ID: "price_999", UnitAmount: 9999},
  40. }, nil)
  41. // Create tiers
  42. require.Nil(t, s.userManager.AddTier(&user.Tier{
  43. ID: "ti_1",
  44. Code: "admin",
  45. Name: "Admin",
  46. }))
  47. require.Nil(t, s.userManager.AddTier(&user.Tier{
  48. ID: "ti_123",
  49. Code: "pro",
  50. Name: "Pro",
  51. MessageLimit: 1000,
  52. MessageExpiryDuration: time.Hour,
  53. EmailLimit: 123,
  54. ReservationLimit: 777,
  55. AttachmentFileSizeLimit: 999,
  56. AttachmentTotalSizeLimit: 888,
  57. AttachmentExpiryDuration: time.Minute,
  58. StripeMonthlyPriceID: "price_123",
  59. StripeYearlyPriceID: "price_124",
  60. }))
  61. require.Nil(t, s.userManager.AddTier(&user.Tier{
  62. ID: "ti_444",
  63. Code: "business",
  64. Name: "Business",
  65. MessageLimit: 2000,
  66. MessageExpiryDuration: 10 * time.Hour,
  67. EmailLimit: 123123,
  68. ReservationLimit: 777333,
  69. AttachmentFileSizeLimit: 999111,
  70. AttachmentTotalSizeLimit: 888111,
  71. AttachmentExpiryDuration: time.Hour,
  72. StripeMonthlyPriceID: "price_456",
  73. StripeYearlyPriceID: "price_457",
  74. }))
  75. response := request(t, s, "GET", "/v1/tiers", "", nil)
  76. require.Equal(t, 200, response.Code)
  77. var tiers []apiAccountBillingTier
  78. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  79. require.Equal(t, 3, len(tiers))
  80. // Free tier
  81. tier := tiers[0]
  82. require.Equal(t, "", tier.Code)
  83. require.Equal(t, "", tier.Name)
  84. require.Equal(t, "ip", tier.Limits.Basis)
  85. require.Equal(t, int64(0), tier.Limits.Reservations)
  86. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  87. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  88. require.Equal(t, int64(24), tier.Limits.Emails)
  89. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  90. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  91. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  92. // Admin tier is not included, because it is not paid!
  93. tier = tiers[1]
  94. require.Equal(t, "pro", tier.Code)
  95. require.Equal(t, "Pro", tier.Name)
  96. require.Equal(t, "tier", tier.Limits.Basis)
  97. require.Equal(t, int64(500), tier.Prices.Month)
  98. require.Equal(t, int64(5000), tier.Prices.Year)
  99. require.Equal(t, int64(777), tier.Limits.Reservations)
  100. require.Equal(t, int64(1000), tier.Limits.Messages)
  101. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  102. require.Equal(t, int64(123), tier.Limits.Emails)
  103. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  104. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  105. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  106. tier = tiers[2]
  107. require.Equal(t, "business", tier.Code)
  108. require.Equal(t, "Business", tier.Name)
  109. require.Equal(t, int64(1000), tier.Prices.Month)
  110. require.Equal(t, int64(10000), tier.Prices.Year)
  111. require.Equal(t, "tier", tier.Limits.Basis)
  112. require.Equal(t, int64(777333), tier.Limits.Reservations)
  113. require.Equal(t, int64(2000), tier.Limits.Messages)
  114. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  115. require.Equal(t, int64(123123), tier.Limits.Emails)
  116. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  117. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  118. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  119. }
  120. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  121. stripeMock := &testStripeAPI{}
  122. defer stripeMock.AssertExpectations(t)
  123. c := newTestConfigWithAuthFile(t)
  124. c.StripeSecretKey = "secret key"
  125. c.StripeWebhookKey = "webhook key"
  126. s := newTestServer(t, c)
  127. s.stripe = stripeMock
  128. // Define how the mock should react
  129. stripeMock.
  130. On("NewCheckoutSession", mock.Anything).
  131. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  132. // Create tier and user
  133. require.Nil(t, s.userManager.AddTier(&user.Tier{
  134. ID: "ti_123",
  135. Code: "pro",
  136. StripeMonthlyPriceID: "price_123",
  137. }))
  138. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  139. // Create subscription
  140. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  141. "Authorization": util.BasicAuth("phil", "phil"),
  142. })
  143. require.Equal(t, 200, response.Code)
  144. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  145. require.Nil(t, err)
  146. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  147. }
  148. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  149. stripeMock := &testStripeAPI{}
  150. defer stripeMock.AssertExpectations(t)
  151. c := newTestConfigWithAuthFile(t)
  152. c.StripeSecretKey = "secret key"
  153. c.StripeWebhookKey = "webhook key"
  154. s := newTestServer(t, c)
  155. s.stripe = stripeMock
  156. // Define how the mock should react
  157. stripeMock.
  158. On("GetCustomer", "acct_123").
  159. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  160. stripeMock.
  161. On("NewCheckoutSession", mock.Anything).
  162. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  163. // Create tier and user
  164. require.Nil(t, s.userManager.AddTier(&user.Tier{
  165. ID: "ti_123",
  166. Code: "pro",
  167. StripeMonthlyPriceID: "price_123",
  168. }))
  169. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  170. u, err := s.userManager.User("phil")
  171. require.Nil(t, err)
  172. billing := &user.Billing{
  173. StripeCustomerID: "acct_123",
  174. }
  175. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  176. // Create subscription
  177. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  178. "Authorization": util.BasicAuth("phil", "phil"),
  179. })
  180. require.Equal(t, 200, response.Code)
  181. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  182. require.Nil(t, err)
  183. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  184. }
  185. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  186. stripeMock := &testStripeAPI{}
  187. defer stripeMock.AssertExpectations(t)
  188. c := newTestConfigWithAuthFile(t)
  189. c.EnableSignup = true
  190. c.StripeSecretKey = "secret key"
  191. c.StripeWebhookKey = "webhook key"
  192. s := newTestServer(t, c)
  193. s.stripe = stripeMock
  194. // Define how the mock should react
  195. stripeMock.
  196. On("CancelSubscription", "sub_123").
  197. Return(&stripe.Subscription{}, nil)
  198. // Create tier and user
  199. require.Nil(t, s.userManager.AddTier(&user.Tier{
  200. ID: "ti_123",
  201. Code: "pro",
  202. StripeMonthlyPriceID: "price_123",
  203. }))
  204. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  205. u, err := s.userManager.User("phil")
  206. require.Nil(t, err)
  207. billing := &user.Billing{
  208. StripeCustomerID: "acct_123",
  209. StripeSubscriptionID: "sub_123",
  210. }
  211. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  212. // Delete account
  213. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  214. "Authorization": util.BasicAuth("phil", "phil"),
  215. })
  216. require.Equal(t, 200, rr.Code)
  217. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  218. "Authorization": util.BasicAuth("phil", "mypass"),
  219. })
  220. require.Equal(t, 401, rr.Code)
  221. }
  222. func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
  223. // This test is too overloaded, but it's also a great end-to-end a test.
  224. //
  225. // It tests:
  226. // - A successful checkout flow (not a paying customer -> paying customer)
  227. // - Tier-changes reset the rate limits for the user
  228. // - The request limits for tier-less user and a tier-user
  229. // - The message limits for a tier-user
  230. stripeMock := &testStripeAPI{}
  231. defer stripeMock.AssertExpectations(t)
  232. c := newTestConfigWithAuthFile(t)
  233. c.StripeSecretKey = "secret key"
  234. c.StripeWebhookKey = "webhook key"
  235. c.VisitorRequestLimitBurst = 5
  236. c.VisitorRequestLimitReplenish = time.Hour
  237. c.CacheBatchSize = 500
  238. c.CacheBatchTimeout = time.Second
  239. s := newTestServer(t, c)
  240. s.stripe = stripeMock
  241. // Create a user with a Stripe subscription and 3 reservations
  242. require.Nil(t, s.userManager.AddTier(&user.Tier{
  243. ID: "ti_123",
  244. Code: "starter",
  245. StripeMonthlyPriceID: "price_1234",
  246. ReservationLimit: 1,
  247. MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
  248. MessageExpiryDuration: time.Hour,
  249. }))
  250. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
  251. u, err := s.userManager.User("phil")
  252. require.Nil(t, err)
  253. // Define how the mock should react
  254. stripeMock.
  255. On("GetSession", "SOMETOKEN").
  256. Return(&stripe.CheckoutSession{
  257. ClientReferenceID: u.ID, // ntfy user ID
  258. Customer: &stripe.Customer{
  259. ID: "acct_5555",
  260. },
  261. Subscription: &stripe.Subscription{
  262. ID: "sub_1234",
  263. },
  264. }, nil)
  265. stripeMock.
  266. On("GetSubscription", "sub_1234").
  267. Return(&stripe.Subscription{
  268. ID: "sub_1234",
  269. Status: stripe.SubscriptionStatusActive,
  270. CurrentPeriodEnd: 123456789,
  271. CancelAt: 0,
  272. Items: &stripe.SubscriptionItemList{
  273. Data: []*stripe.SubscriptionItem{
  274. {
  275. Price: &stripe.Price{
  276. ID: "price_1234",
  277. Recurring: &stripe.PriceRecurring{
  278. Interval: stripe.PriceRecurringIntervalMonth,
  279. },
  280. },
  281. },
  282. },
  283. },
  284. }, nil)
  285. stripeMock.
  286. On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{
  287. Params: stripe.Params{
  288. Metadata: map[string]string{
  289. "user_id": u.ID,
  290. "user_name": u.Name,
  291. },
  292. },
  293. }).
  294. Return(&stripe.Customer{}, nil)
  295. // Send messages until rate limit of free tier is hit
  296. for i := 0; i < 5; i++ {
  297. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  298. "Authorization": util.BasicAuth("phil", "phil"),
  299. })
  300. require.Equal(t, 200, rr.Code)
  301. }
  302. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  303. "Authorization": util.BasicAuth("phil", "phil"),
  304. })
  305. require.Equal(t, 429, rr.Code)
  306. // Verify some "before-stats"
  307. u, err = s.userManager.User("phil")
  308. require.Nil(t, err)
  309. require.Nil(t, u.Tier)
  310. require.Equal(t, "", u.Billing.StripeCustomerID)
  311. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  312. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  313. require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
  314. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  315. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  316. require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
  317. require.Equal(t, int64(0), u.Stats.Emails)
  318. // Simulate Stripe success return URL call (no user context)
  319. rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
  320. require.Equal(t, 303, rr.Code)
  321. // Verify that database columns were updated
  322. u, err = s.userManager.User("phil")
  323. require.Nil(t, err)
  324. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  325. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  326. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  327. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
  328. require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
  329. require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
  330. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  331. require.Equal(t, int64(0), u.Stats.Messages)
  332. require.Equal(t, int64(0), u.Stats.Emails)
  333. // Now for the fun part: Verify that new rate limits are immediately applied
  334. // This only tests the request limiter, which kicks in before the message limiter.
  335. for i := 0; i < 11; i++ {
  336. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  337. "Authorization": util.BasicAuth("phil", "phil"),
  338. })
  339. require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
  340. }
  341. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  342. "Authorization": util.BasicAuth("phil", "phil"),
  343. })
  344. require.Equal(t, 429, rr.Code)
  345. // Now let's test the message limiter by faking a ridiculously generous rate limiter
  346. v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
  347. v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
  348. var wg sync.WaitGroup
  349. for i := 0; i < 209; i++ {
  350. wg.Add(1)
  351. go func(i int) {
  352. defer wg.Done()
  353. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  354. "Authorization": util.BasicAuth("phil", "phil"),
  355. })
  356. require.Equal(t, 200, rr.Code, "Failed on %d", i)
  357. }(i)
  358. }
  359. wg.Wait()
  360. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  361. "Authorization": util.BasicAuth("phil", "phil"),
  362. })
  363. require.Equal(t, 429, rr.Code)
  364. // And now let's cross-check that the stats are correct too
  365. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  366. "Authorization": util.BasicAuth("phil", "phil"),
  367. })
  368. require.Equal(t, 200, rr.Code)
  369. account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
  370. require.Equal(t, int64(220), account.Limits.Messages)
  371. require.Equal(t, int64(220), account.Stats.Messages)
  372. require.Equal(t, int64(0), account.Stats.MessagesRemaining)
  373. }
// TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active delivers a
// "customer.subscription.updated" webhook (the canned subscriptionUpdatedEventJSON event) for
// a user on the "pro" tier whose subscription now points to the "starter" price.
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
	// This tests incoming webhooks from Stripe to update a subscription:
	// - All Stripe columns are updated in the user table
	// - When downgrading, excess reservations are deleted, including messages and attachments in
	//   the corresponding topics
	stripeMock := &testStripeAPI{}
	defer stripeMock.AssertExpectations(t)
	c := newTestConfigWithAuthFile(t)
	c.StripeSecretKey = "secret key"
	c.StripeWebhookKey = "webhook key"
	s := newTestServer(t, c)
	s.stripe = stripeMock
	// Define how the mock should react: any payload arriving with the expected signature header
	// and webhook key yields the canned "subscription updated" event (status "active",
	// price_1234, yearly interval, new period end and cancel-at timestamps)
	stripeMock.
		On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
		Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
	// Create two tiers: "starter" (price_1234, 1 reservation) is the tier the webhook's price
	// maps to; "pro" (price_1111, 3 reservations) is the user's current tier
	require.Nil(t, s.userManager.AddTier(&user.Tier{
		ID:                       "ti_1",
		Code:                     "starter",
		StripeMonthlyPriceID:     "price_1234", // !
		ReservationLimit:         1,            // !
		MessageLimit:             100,
		MessageExpiryDuration:    time.Hour,
		AttachmentExpiryDuration: time.Hour,
		AttachmentFileSizeLimit:  1000000,
		AttachmentTotalSizeLimit: 1000000,
		AttachmentBandwidthLimit: 1000000,
	}))
	require.Nil(t, s.userManager.AddTier(&user.Tier{
		ID:                       "ti_2",
		Code:                     "pro",
		StripeMonthlyPriceID:     "price_1111", // !
		ReservationLimit:         3,            // !
		MessageLimit:             200,
		MessageExpiryDuration:    time.Hour,
		AttachmentExpiryDuration: time.Hour,
		AttachmentFileSizeLimit:  1000000,
		AttachmentTotalSizeLimit: 1000000,
		AttachmentBandwidthLimit: 1000000,
	}))
	// Put user "phil" on "pro" with 2 reservations, one more than "starter" allows
	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
	require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
	require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
	// Add billing details: past due, monthly, with old timestamps. The webhook is expected to
	// overwrite the status, interval and both timestamps.
	u, err := s.userManager.User("phil")
	require.Nil(t, err)
	billing := &user.Billing{
		StripeCustomerID:            "acct_5555",
		StripeSubscriptionID:        "sub_1234",
		StripeSubscriptionStatus:    stripe.SubscriptionStatusPastDue,
		StripeSubscriptionInterval:  stripe.PriceRecurringIntervalMonth,
		StripeSubscriptionPaidUntil: time.Unix(123, 0),
		StripeSubscriptionCancelAt:  time.Unix(456, 0),
	}
	require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
	// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted.
	// The 5000-byte bodies end up as attachment files in AttachmentCacheDir (checked with
	// FileExists below).
	rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	a2 := toMessage(t, rr.Body.String())
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
	rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	z2 := toMessage(t, rr.Body.String())
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
	// Call the webhook: This does all the magic
	rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
		"Stripe-Signature": "stripe signature",
	})
	require.Equal(t, 200, rr.Code)
	// Verify that database columns were updated
	u, err = s.userManager.User("phil")
	require.Nil(t, err)
	require.Equal(t, "starter", u.Tier.Code) // Not "pro"
	require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
	require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
	require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)     // Not "past_due"
	require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
	require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix())         // Updated
	require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix())          // Updated
	// Verify that reservations were deleted
	r, err := s.userManager.Reservations("phil")
	require.Nil(t, err)
	require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
	require.Equal(t, "atopic", r[0].Topic)
	// Verify that messages and attachments were deleted. Messages in "atopic" and its
	// attachment file survive; "ztopic" is emptied and its attachment file is removed.
	// NOTE(review): the sleep presumably gives asynchronous work time to settle before the
	// manager runs; confirm it is still needed.
	time.Sleep(time.Second)
	s.execManager()
	ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
	require.Nil(t, err)
	require.Equal(t, 2, len(ms))
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
	ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
	require.Nil(t, err)
	require.Equal(t, 0, len(ms))
	require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
}
  484. func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
  485. // This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
  486. // updated (all Stripe fields are deleted, and the tier is removed).
  487. //
  488. // It doesn't fully test the message/attachment deletion. That is tested above in the subscription update call.
  489. stripeMock := &testStripeAPI{}
  490. defer stripeMock.AssertExpectations(t)
  491. c := newTestConfigWithAuthFile(t)
  492. c.StripeSecretKey = "secret key"
  493. c.StripeWebhookKey = "webhook key"
  494. s := newTestServer(t, c)
  495. s.stripe = stripeMock
  496. // Define how the mock should react
  497. stripeMock.
  498. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  499. Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil)
  500. // Create a user with a Stripe subscription and 3 reservations
  501. require.Nil(t, s.userManager.AddTier(&user.Tier{
  502. ID: "ti_1",
  503. Code: "pro",
  504. StripeMonthlyPriceID: "price_1234",
  505. ReservationLimit: 1,
  506. }))
  507. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  508. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  509. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  510. // Add billing details
  511. u, err := s.userManager.User("phil")
  512. require.Nil(t, err)
  513. require.Nil(t, s.userManager.ChangeBilling(u.Name, &user.Billing{
  514. StripeCustomerID: "acct_5555",
  515. StripeSubscriptionID: "sub_1234",
  516. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  517. StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
  518. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  519. StripeSubscriptionCancelAt: time.Unix(0, 0),
  520. }))
  521. // Call the webhook: This does all the magic
  522. rr := request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  523. "Stripe-Signature": "stripe signature",
  524. })
  525. require.Equal(t, 200, rr.Code)
  526. // Verify that database columns were updated
  527. u, err = s.userManager.User("phil")
  528. require.Nil(t, err)
  529. require.Nil(t, u.Tier)
  530. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  531. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  532. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  533. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  534. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  535. // Verify that reservations were deleted
  536. r, err := s.userManager.Reservations("phil")
  537. require.Nil(t, err)
  538. require.Equal(t, 0, len(r))
  539. }
  540. func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
  541. stripeMock := &testStripeAPI{}
  542. defer stripeMock.AssertExpectations(t)
  543. c := newTestConfigWithAuthFile(t)
  544. c.StripeSecretKey = "secret key"
  545. c.StripeWebhookKey = "webhook key"
  546. s := newTestServer(t, c)
  547. s.stripe = stripeMock
  548. // Define how the mock should react
  549. stripeMock.
  550. On("GetSubscription", "sub_123").
  551. Return(&stripe.Subscription{
  552. ID: "sub_123",
  553. Items: &stripe.SubscriptionItemList{
  554. Data: []*stripe.SubscriptionItem{
  555. {
  556. ID: "someid_123",
  557. Price: &stripe.Price{ID: "price_123"},
  558. },
  559. },
  560. },
  561. }, nil)
  562. stripeMock.
  563. On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
  564. CancelAtPeriodEnd: stripe.Bool(false),
  565. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
  566. Items: []*stripe.SubscriptionItemsParams{
  567. {
  568. ID: stripe.String("someid_123"),
  569. Price: stripe.String("price_457"),
  570. },
  571. },
  572. }).
  573. Return(&stripe.Subscription{}, nil)
  574. // Create tier and user
  575. require.Nil(t, s.userManager.AddTier(&user.Tier{
  576. ID: "ti_123",
  577. Code: "pro",
  578. StripeMonthlyPriceID: "price_123",
  579. StripeYearlyPriceID: "price_124",
  580. }))
  581. require.Nil(t, s.userManager.AddTier(&user.Tier{
  582. ID: "ti_456",
  583. Code: "business",
  584. StripeMonthlyPriceID: "price_456",
  585. StripeYearlyPriceID: "price_457",
  586. }))
  587. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  588. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  589. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  590. StripeCustomerID: "acct_123",
  591. StripeSubscriptionID: "sub_123",
  592. }))
  593. // Call endpoint to change subscription
  594. rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
  595. "Authorization": util.BasicAuth("phil", "phil"),
  596. })
  597. require.Equal(t, 200, rr.Code)
  598. }
  599. func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
  600. stripeMock := &testStripeAPI{}
  601. defer stripeMock.AssertExpectations(t)
  602. c := newTestConfigWithAuthFile(t)
  603. c.StripeSecretKey = "secret key"
  604. c.StripeWebhookKey = "webhook key"
  605. s := newTestServer(t, c)
  606. s.stripe = stripeMock
  607. // Define how the mock should react
  608. stripeMock.
  609. On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool {
  610. return *s.CancelAtPeriodEnd // Is true
  611. })).
  612. Return(&stripe.Subscription{}, nil)
  613. // Create user
  614. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  615. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  616. StripeCustomerID: "acct_123",
  617. StripeSubscriptionID: "sub_123",
  618. }))
  619. // Delete subscription
  620. rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{
  621. "Authorization": util.BasicAuth("phil", "phil"),
  622. })
  623. require.Equal(t, 200, rr.Code)
  624. }
  625. func TestPayments_CreatePortalSession(t *testing.T) {
  626. stripeMock := &testStripeAPI{}
  627. defer stripeMock.AssertExpectations(t)
  628. c := newTestConfigWithAuthFile(t)
  629. c.StripeSecretKey = "secret key"
  630. c.StripeWebhookKey = "webhook key"
  631. s := newTestServer(t, c)
  632. s.stripe = stripeMock
  633. // Define how the mock should react
  634. stripeMock.
  635. On("NewPortalSession", &stripe.BillingPortalSessionParams{
  636. Customer: stripe.String("acct_123"),
  637. ReturnURL: stripe.String(s.config.BaseURL),
  638. }).
  639. Return(&stripe.BillingPortalSession{
  640. URL: "https://billing.stripe.com/blablabla",
  641. }, nil)
  642. // Create user
  643. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  644. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  645. StripeCustomerID: "acct_123",
  646. StripeSubscriptionID: "sub_123",
  647. }))
  648. // Create portal session
  649. rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{
  650. "Authorization": util.BasicAuth("phil", "phil"),
  651. })
  652. require.Equal(t, 200, rr.Code)
  653. ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body))
  654. require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL)
  655. }
  656. type testStripeAPI struct {
  657. mock.Mock
  658. }
  659. var _ stripeAPI = (*testStripeAPI)(nil)
  660. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  661. args := s.Called(params)
  662. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  663. }
  664. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  665. args := s.Called(params)
  666. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  667. }
  668. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  669. args := s.Called(params)
  670. return args.Get(0).([]*stripe.Price), args.Error(1)
  671. }
  672. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  673. args := s.Called(id)
  674. return args.Get(0).(*stripe.Customer), args.Error(1)
  675. }
  676. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  677. args := s.Called(id)
  678. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  679. }
  680. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  681. args := s.Called(id)
  682. return args.Get(0).(*stripe.Subscription), args.Error(1)
  683. }
  684. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  685. args := s.Called(id, params)
  686. return args.Get(0).(*stripe.Customer), args.Error(1)
  687. }
  688. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  689. args := s.Called(id, params)
  690. return args.Get(0).(*stripe.Subscription), args.Error(1)
  691. }
  692. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  693. args := s.Called(id)
  694. return args.Get(0).(*stripe.Subscription), args.Error(1)
  695. }
  696. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  697. args := s.Called(payload, header, secret)
  698. return args.Get(0).(stripe.Event), args.Error(1)
  699. }
  700. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  701. var e stripe.Event
  702. if err := json.Unmarshal([]byte(v), &e); err != nil {
  703. t.Fatal(err)
  704. }
  705. return e
  706. }
// subscriptionUpdatedEventJSON is a minimal "customer.subscription.updated" Stripe event.
// It describes subscription sub_1234 of customer acct_5555 with status "active" and price
// price_1234 on a yearly interval. current_period_end is 1674268231 and cancel_at is
// 1674299999.
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234",
"recurring": {
"interval": "year"
}
}
}
]
}
}
}
}`
// subscriptionDeletedEventJSON is a minimal "customer.subscription.deleted" Stripe event.
// It describes subscription sub_1234 of customer acct_5555 (price price_1234, monthly
// interval). The remaining fields mirror subscriptionUpdatedEventJSON.
const subscriptionDeletedEventJSON = `
{
"type": "customer.subscription.deleted",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234",
"recurring": {
"interval": "month"
}
}
}
]
}
}
}
}`