server_payments_test.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859
  1. //go:build !nopayments
  2. package server
  3. import (
  4. "encoding/json"
  5. "github.com/stretchr/testify/mock"
  6. "github.com/stretchr/testify/require"
  7. "github.com/stripe/stripe-go/v74"
  8. "golang.org/x/time/rate"
  9. "heckel.io/ntfy/v2/payments"
  10. "heckel.io/ntfy/v2/user"
  11. "heckel.io/ntfy/v2/util"
  12. "io"
  13. "net/netip"
  14. "path/filepath"
  15. "strings"
  16. "sync"
  17. "testing"
  18. "time"
  19. )
  20. func TestPayments_Tiers(t *testing.T) {
  21. stripeMock := &testStripeAPI{}
  22. defer stripeMock.AssertExpectations(t)
  23. c := newTestConfigWithAuthFile(t)
  24. c.StripeSecretKey = "secret key"
  25. c.StripeWebhookKey = "webhook key"
  26. c.VisitorRequestLimitReplenish = 12 * time.Hour
  27. c.CacheDuration = 13 * time.Hour
  28. c.AttachmentFileSizeLimit = 111
  29. c.VisitorAttachmentTotalSizeLimit = 222
  30. c.AttachmentExpiryDuration = 123 * time.Second
  31. s := newTestServer(t, c)
  32. s.stripe = stripeMock
  33. // Define how the mock should react
  34. stripeMock.
  35. On("ListPrices", mock.Anything).
  36. Return([]*stripe.Price{
  37. {ID: "price_123", UnitAmount: 500},
  38. {ID: "price_124", UnitAmount: 5000},
  39. {ID: "price_456", UnitAmount: 1000},
  40. {ID: "price_457", UnitAmount: 10000},
  41. {ID: "price_999", UnitAmount: 9999},
  42. }, nil)
  43. // Create tiers
  44. require.Nil(t, s.userManager.AddTier(&user.Tier{
  45. ID: "ti_1",
  46. Code: "admin",
  47. Name: "Admin",
  48. }))
  49. require.Nil(t, s.userManager.AddTier(&user.Tier{
  50. ID: "ti_123",
  51. Code: "pro",
  52. Name: "Pro",
  53. MessageLimit: 1000,
  54. MessageExpiryDuration: time.Hour,
  55. EmailLimit: 123,
  56. ReservationLimit: 777,
  57. AttachmentFileSizeLimit: 999,
  58. AttachmentTotalSizeLimit: 888,
  59. AttachmentExpiryDuration: time.Minute,
  60. StripeMonthlyPriceID: "price_123",
  61. StripeYearlyPriceID: "price_124",
  62. }))
  63. require.Nil(t, s.userManager.AddTier(&user.Tier{
  64. ID: "ti_444",
  65. Code: "business",
  66. Name: "Business",
  67. MessageLimit: 2000,
  68. MessageExpiryDuration: 10 * time.Hour,
  69. EmailLimit: 123123,
  70. ReservationLimit: 777333,
  71. AttachmentFileSizeLimit: 999111,
  72. AttachmentTotalSizeLimit: 888111,
  73. AttachmentExpiryDuration: time.Hour,
  74. StripeMonthlyPriceID: "price_456",
  75. StripeYearlyPriceID: "price_457",
  76. }))
  77. response := request(t, s, "GET", "/v1/tiers", "", nil)
  78. require.Equal(t, 200, response.Code)
  79. var tiers []apiAccountBillingTier
  80. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  81. require.Equal(t, 3, len(tiers))
  82. // Free tier
  83. tier := tiers[0]
  84. require.Equal(t, "", tier.Code)
  85. require.Equal(t, "", tier.Name)
  86. require.Equal(t, "ip", tier.Limits.Basis)
  87. require.Equal(t, int64(0), tier.Limits.Reservations)
  88. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  89. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  90. require.Equal(t, int64(24), tier.Limits.Emails)
  91. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  92. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  93. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  94. // Admin tier is not included, because it is not paid!
  95. tier = tiers[1]
  96. require.Equal(t, "pro", tier.Code)
  97. require.Equal(t, "Pro", tier.Name)
  98. require.Equal(t, "tier", tier.Limits.Basis)
  99. require.Equal(t, int64(500), tier.Prices.Month)
  100. require.Equal(t, int64(5000), tier.Prices.Year)
  101. require.Equal(t, int64(777), tier.Limits.Reservations)
  102. require.Equal(t, int64(1000), tier.Limits.Messages)
  103. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  104. require.Equal(t, int64(123), tier.Limits.Emails)
  105. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  106. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  107. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  108. tier = tiers[2]
  109. require.Equal(t, "business", tier.Code)
  110. require.Equal(t, "Business", tier.Name)
  111. require.Equal(t, int64(1000), tier.Prices.Month)
  112. require.Equal(t, int64(10000), tier.Prices.Year)
  113. require.Equal(t, "tier", tier.Limits.Basis)
  114. require.Equal(t, int64(777333), tier.Limits.Reservations)
  115. require.Equal(t, int64(2000), tier.Limits.Messages)
  116. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  117. require.Equal(t, int64(123123), tier.Limits.Emails)
  118. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  119. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  120. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  121. }
  122. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  123. stripeMock := &testStripeAPI{}
  124. defer stripeMock.AssertExpectations(t)
  125. c := newTestConfigWithAuthFile(t)
  126. c.StripeSecretKey = "secret key"
  127. c.StripeWebhookKey = "webhook key"
  128. s := newTestServer(t, c)
  129. s.stripe = stripeMock
  130. // Define how the mock should react
  131. stripeMock.
  132. On("NewCheckoutSession", mock.Anything).
  133. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  134. // Create tier and user
  135. require.Nil(t, s.userManager.AddTier(&user.Tier{
  136. ID: "ti_123",
  137. Code: "pro",
  138. StripeMonthlyPriceID: "price_123",
  139. }))
  140. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  141. // Create subscription
  142. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  143. "Authorization": util.BasicAuth("phil", "phil"),
  144. })
  145. require.Equal(t, 200, response.Code)
  146. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  147. require.Nil(t, err)
  148. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  149. }
  150. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  151. stripeMock := &testStripeAPI{}
  152. defer stripeMock.AssertExpectations(t)
  153. c := newTestConfigWithAuthFile(t)
  154. c.StripeSecretKey = "secret key"
  155. c.StripeWebhookKey = "webhook key"
  156. s := newTestServer(t, c)
  157. s.stripe = stripeMock
  158. // Define how the mock should react
  159. stripeMock.
  160. On("GetCustomer", "acct_123").
  161. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  162. stripeMock.
  163. On("NewCheckoutSession", mock.Anything).
  164. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  165. // Create tier and user
  166. require.Nil(t, s.userManager.AddTier(&user.Tier{
  167. ID: "ti_123",
  168. Code: "pro",
  169. StripeMonthlyPriceID: "price_123",
  170. }))
  171. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  172. u, err := s.userManager.User("phil")
  173. require.Nil(t, err)
  174. billing := &user.Billing{
  175. StripeCustomerID: "acct_123",
  176. }
  177. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  178. // Create subscription
  179. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  180. "Authorization": util.BasicAuth("phil", "phil"),
  181. })
  182. require.Equal(t, 200, response.Code)
  183. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  184. require.Nil(t, err)
  185. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  186. }
  187. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  188. stripeMock := &testStripeAPI{}
  189. defer stripeMock.AssertExpectations(t)
  190. c := newTestConfigWithAuthFile(t)
  191. c.EnableSignup = true
  192. c.StripeSecretKey = "secret key"
  193. c.StripeWebhookKey = "webhook key"
  194. s := newTestServer(t, c)
  195. s.stripe = stripeMock
  196. // Define how the mock should react
  197. stripeMock.
  198. On("CancelSubscription", "sub_123").
  199. Return(&stripe.Subscription{}, nil)
  200. // Create tier and user
  201. require.Nil(t, s.userManager.AddTier(&user.Tier{
  202. ID: "ti_123",
  203. Code: "pro",
  204. StripeMonthlyPriceID: "price_123",
  205. }))
  206. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  207. u, err := s.userManager.User("phil")
  208. require.Nil(t, err)
  209. billing := &user.Billing{
  210. StripeCustomerID: "acct_123",
  211. StripeSubscriptionID: "sub_123",
  212. }
  213. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  214. // Delete account
  215. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  216. "Authorization": util.BasicAuth("phil", "phil"),
  217. })
  218. require.Equal(t, 200, rr.Code)
  219. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  220. "Authorization": util.BasicAuth("phil", "mypass"),
  221. })
  222. require.Equal(t, 401, rr.Code)
  223. }
  224. func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
  225. // This test is too overloaded, but it's also a great end-to-end a test.
  226. //
  227. // It tests:
  228. // - A successful checkout flow (not a paying customer -> paying customer)
  229. // - Tier-changes reset the rate limits for the user
  230. // - The request limits for tier-less user and a tier-user
  231. // - The message limits for a tier-user
  232. stripeMock := &testStripeAPI{}
  233. defer stripeMock.AssertExpectations(t)
  234. c := newTestConfigWithAuthFile(t)
  235. c.StripeSecretKey = "secret key"
  236. c.StripeWebhookKey = "webhook key"
  237. c.VisitorRequestLimitBurst = 5
  238. c.VisitorRequestLimitReplenish = time.Hour
  239. c.CacheBatchSize = 500
  240. c.CacheBatchTimeout = time.Second
  241. s := newTestServer(t, c)
  242. s.stripe = stripeMock
  243. // Create a user with a Stripe subscription and 3 reservations
  244. require.Nil(t, s.userManager.AddTier(&user.Tier{
  245. ID: "ti_123",
  246. Code: "starter",
  247. StripeMonthlyPriceID: "price_1234",
  248. ReservationLimit: 1,
  249. MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
  250. MessageExpiryDuration: time.Hour,
  251. }))
  252. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false)) // No tier
  253. u, err := s.userManager.User("phil")
  254. require.Nil(t, err)
  255. // Define how the mock should react
  256. stripeMock.
  257. On("GetSession", "SOMETOKEN").
  258. Return(&stripe.CheckoutSession{
  259. ClientReferenceID: u.ID, // ntfy user ID
  260. Customer: &stripe.Customer{
  261. ID: "acct_5555",
  262. },
  263. Subscription: &stripe.Subscription{
  264. ID: "sub_1234",
  265. },
  266. }, nil)
  267. stripeMock.
  268. On("GetSubscription", "sub_1234").
  269. Return(&stripe.Subscription{
  270. ID: "sub_1234",
  271. Status: stripe.SubscriptionStatusActive,
  272. CurrentPeriodEnd: 123456789,
  273. CancelAt: 0,
  274. Items: &stripe.SubscriptionItemList{
  275. Data: []*stripe.SubscriptionItem{
  276. {
  277. Price: &stripe.Price{
  278. ID: "price_1234",
  279. Recurring: &stripe.PriceRecurring{
  280. Interval: stripe.PriceRecurringIntervalMonth,
  281. },
  282. },
  283. },
  284. },
  285. },
  286. }, nil)
  287. stripeMock.
  288. On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{
  289. Params: stripe.Params{
  290. Metadata: map[string]string{
  291. "user_id": u.ID,
  292. "user_name": u.Name,
  293. },
  294. },
  295. }).
  296. Return(&stripe.Customer{}, nil)
  297. // Send messages until rate limit of free tier is hit
  298. for i := 0; i < 5; i++ {
  299. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  300. "Authorization": util.BasicAuth("phil", "phil"),
  301. })
  302. require.Equal(t, 200, rr.Code)
  303. }
  304. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  305. "Authorization": util.BasicAuth("phil", "phil"),
  306. })
  307. require.Equal(t, 429, rr.Code)
  308. // Verify some "before-stats"
  309. u, err = s.userManager.User("phil")
  310. require.Nil(t, err)
  311. require.Nil(t, u.Tier)
  312. require.Equal(t, "", u.Billing.StripeCustomerID)
  313. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  314. require.Equal(t, payments.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  315. require.Equal(t, payments.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
  316. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  317. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  318. require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
  319. require.Equal(t, int64(0), u.Stats.Emails)
  320. // Simulate Stripe success return URL call (no user context)
  321. rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
  322. require.Equal(t, 303, rr.Code)
  323. // Verify that database columns were updated
  324. u, err = s.userManager.User("phil")
  325. require.Nil(t, err)
  326. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  327. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  328. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  329. require.Equal(t, payments.SubscriptionStatus(stripe.SubscriptionStatusActive), u.Billing.StripeSubscriptionStatus)
  330. require.Equal(t, payments.PriceRecurringInterval(stripe.PriceRecurringIntervalMonth), u.Billing.StripeSubscriptionInterval)
  331. require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
  332. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  333. require.Equal(t, int64(0), u.Stats.Messages)
  334. require.Equal(t, int64(0), u.Stats.Emails)
  335. // Now for the fun part: Verify that new rate limits are immediately applied
  336. // This only tests the request limiter, which kicks in before the message limiter.
  337. for i := 0; i < 11; i++ {
  338. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  339. "Authorization": util.BasicAuth("phil", "phil"),
  340. })
  341. require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
  342. }
  343. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  344. "Authorization": util.BasicAuth("phil", "phil"),
  345. })
  346. require.Equal(t, 429, rr.Code)
  347. // Now let's test the message limiter by faking a ridiculously generous rate limiter
  348. v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
  349. v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
  350. var wg sync.WaitGroup
  351. for i := 0; i < 209; i++ {
  352. wg.Add(1)
  353. go func(i int) {
  354. defer wg.Done()
  355. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  356. "Authorization": util.BasicAuth("phil", "phil"),
  357. })
  358. require.Equal(t, 200, rr.Code, "Failed on %d", i)
  359. }(i)
  360. }
  361. wg.Wait()
  362. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  363. "Authorization": util.BasicAuth("phil", "phil"),
  364. })
  365. require.Equal(t, 429, rr.Code)
  366. // And now let's cross-check that the stats are correct too
  367. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  368. "Authorization": util.BasicAuth("phil", "phil"),
  369. })
  370. require.Equal(t, 200, rr.Code)
  371. account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
  372. require.Equal(t, int64(220), account.Limits.Messages)
  373. require.Equal(t, int64(220), account.Stats.Messages)
  374. require.Equal(t, int64(0), account.Stats.MessagesRemaining)
  375. }
  376. func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
  377. t.Parallel()
  378. // This tests incoming webhooks from Stripe to update a subscription:
  379. // - All Stripe columns are updated in the user table
  380. // - When downgrading, excess reservations are deleted, including messages and attachments in
  381. // the corresponding topics
  382. stripeMock := &testStripeAPI{}
  383. defer stripeMock.AssertExpectations(t)
  384. c := newTestConfigWithAuthFile(t)
  385. c.StripeSecretKey = "secret key"
  386. c.StripeWebhookKey = "webhook key"
  387. s := newTestServer(t, c)
  388. s.stripe = stripeMock
  389. // Define how the mock should react
  390. stripeMock.
  391. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  392. Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
  393. // Create a user with a Stripe subscription and 3 reservations
  394. require.Nil(t, s.userManager.AddTier(&user.Tier{
  395. ID: "ti_1",
  396. Code: "starter",
  397. StripeMonthlyPriceID: "price_1234", // !
  398. ReservationLimit: 1, // !
  399. MessageLimit: 100,
  400. MessageExpiryDuration: time.Hour,
  401. AttachmentExpiryDuration: time.Hour,
  402. AttachmentFileSizeLimit: 1000000,
  403. AttachmentTotalSizeLimit: 1000000,
  404. AttachmentBandwidthLimit: 1000000,
  405. }))
  406. require.Nil(t, s.userManager.AddTier(&user.Tier{
  407. ID: "ti_2",
  408. Code: "pro",
  409. StripeMonthlyPriceID: "price_1111", // !
  410. ReservationLimit: 3, // !
  411. MessageLimit: 200,
  412. MessageExpiryDuration: time.Hour,
  413. AttachmentExpiryDuration: time.Hour,
  414. AttachmentFileSizeLimit: 1000000,
  415. AttachmentTotalSizeLimit: 1000000,
  416. AttachmentBandwidthLimit: 1000000,
  417. }))
  418. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  419. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  420. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  421. require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
  422. // Add billing details
  423. u, err := s.userManager.User("phil")
  424. require.Nil(t, err)
  425. billing := &user.Billing{
  426. StripeCustomerID: "acct_5555",
  427. StripeSubscriptionID: "sub_1234",
  428. StripeSubscriptionStatus: payments.SubscriptionStatus(stripe.SubscriptionStatusPastDue),
  429. StripeSubscriptionInterval: payments.PriceRecurringInterval(stripe.PriceRecurringIntervalMonth),
  430. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  431. StripeSubscriptionCancelAt: time.Unix(456, 0),
  432. }
  433. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  434. // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
  435. rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
  436. "Authorization": util.BasicAuth("phil", "phil"),
  437. })
  438. require.Equal(t, 200, rr.Code)
  439. rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
  440. "Authorization": util.BasicAuth("phil", "phil"),
  441. })
  442. require.Equal(t, 200, rr.Code)
  443. a2 := toMessage(t, rr.Body.String())
  444. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  445. rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
  446. "Authorization": util.BasicAuth("phil", "phil"),
  447. })
  448. require.Equal(t, 200, rr.Code)
  449. rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
  450. "Authorization": util.BasicAuth("phil", "phil"),
  451. })
  452. require.Equal(t, 200, rr.Code)
  453. z2 := toMessage(t, rr.Body.String())
  454. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  455. // Call the webhook: This does all the magic
  456. rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  457. "Stripe-Signature": "stripe signature",
  458. })
  459. require.Equal(t, 200, rr.Code)
  460. // Verify that database columns were updated
  461. u, err = s.userManager.User("phil")
  462. require.Nil(t, err)
  463. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  464. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  465. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  466. require.Equal(t, payments.SubscriptionStatus(stripe.SubscriptionStatusActive), u.Billing.StripeSubscriptionStatus) // Not "past_due"
  467. require.Equal(t, payments.PriceRecurringInterval(stripe.PriceRecurringIntervalYear), u.Billing.StripeSubscriptionInterval) // Not "month"
  468. require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
  469. require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
  470. // Verify that reservations were deleted
  471. r, err := s.userManager.Reservations("phil")
  472. require.Nil(t, err)
  473. require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
  474. require.Equal(t, "atopic", r[0].Topic)
  475. // Verify that messages and attachments were deleted
  476. time.Sleep(time.Second)
  477. s.execManager()
  478. ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
  479. require.Nil(t, err)
  480. require.Equal(t, 2, len(ms))
  481. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  482. ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
  483. require.Nil(t, err)
  484. require.Equal(t, 0, len(ms))
  485. require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  486. }
  487. func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
  488. // This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
  489. // updated (all Stripe fields are deleted, and the tier is removed).
  490. //
  491. // It doesn't fully test the message/attachment deletion. That is tested above in the subscription update call.
  492. stripeMock := &testStripeAPI{}
  493. defer stripeMock.AssertExpectations(t)
  494. c := newTestConfigWithAuthFile(t)
  495. c.StripeSecretKey = "secret key"
  496. c.StripeWebhookKey = "webhook key"
  497. s := newTestServer(t, c)
  498. s.stripe = stripeMock
  499. // Define how the mock should react
  500. stripeMock.
  501. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  502. Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil)
  503. // Create a user with a Stripe subscription and 3 reservations
  504. require.Nil(t, s.userManager.AddTier(&user.Tier{
  505. ID: "ti_1",
  506. Code: "pro",
  507. StripeMonthlyPriceID: "price_1234",
  508. ReservationLimit: 1,
  509. }))
  510. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  511. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  512. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  513. // Add billing details
  514. u, err := s.userManager.User("phil")
  515. require.Nil(t, err)
  516. require.Nil(t, s.userManager.ChangeBilling(u.Name, &user.Billing{
  517. StripeCustomerID: "acct_5555",
  518. StripeSubscriptionID: "sub_1234",
  519. StripeSubscriptionStatus: payments.SubscriptionStatus(stripe.SubscriptionStatusPastDue),
  520. StripeSubscriptionInterval: payments.PriceRecurringInterval(stripe.PriceRecurringIntervalMonth),
  521. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  522. StripeSubscriptionCancelAt: time.Unix(0, 0),
  523. }))
  524. // Call the webhook: This does all the magic
  525. rr := request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  526. "Stripe-Signature": "stripe signature",
  527. })
  528. require.Equal(t, 200, rr.Code)
  529. // Verify that database columns were updated
  530. u, err = s.userManager.User("phil")
  531. require.Nil(t, err)
  532. require.Nil(t, u.Tier)
  533. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  534. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  535. require.Equal(t, payments.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  536. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  537. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  538. // Verify that reservations were deleted
  539. r, err := s.userManager.Reservations("phil")
  540. require.Nil(t, err)
  541. require.Equal(t, 0, len(r))
  542. }
  543. func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
  544. stripeMock := &testStripeAPI{}
  545. defer stripeMock.AssertExpectations(t)
  546. c := newTestConfigWithAuthFile(t)
  547. c.StripeSecretKey = "secret key"
  548. c.StripeWebhookKey = "webhook key"
  549. s := newTestServer(t, c)
  550. s.stripe = stripeMock
  551. // Define how the mock should react
  552. stripeMock.
  553. On("GetSubscription", "sub_123").
  554. Return(&stripe.Subscription{
  555. ID: "sub_123",
  556. Items: &stripe.SubscriptionItemList{
  557. Data: []*stripe.SubscriptionItem{
  558. {
  559. ID: "someid_123",
  560. Price: &stripe.Price{ID: "price_123"},
  561. },
  562. },
  563. },
  564. }, nil)
  565. stripeMock.
  566. On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
  567. CancelAtPeriodEnd: stripe.Bool(false),
  568. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
  569. Items: []*stripe.SubscriptionItemsParams{
  570. {
  571. ID: stripe.String("someid_123"),
  572. Price: stripe.String("price_457"),
  573. },
  574. },
  575. }).
  576. Return(&stripe.Subscription{}, nil)
  577. // Create tier and user
  578. require.Nil(t, s.userManager.AddTier(&user.Tier{
  579. ID: "ti_123",
  580. Code: "pro",
  581. StripeMonthlyPriceID: "price_123",
  582. StripeYearlyPriceID: "price_124",
  583. }))
  584. require.Nil(t, s.userManager.AddTier(&user.Tier{
  585. ID: "ti_456",
  586. Code: "business",
  587. StripeMonthlyPriceID: "price_456",
  588. StripeYearlyPriceID: "price_457",
  589. }))
  590. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  591. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  592. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  593. StripeCustomerID: "acct_123",
  594. StripeSubscriptionID: "sub_123",
  595. }))
  596. // Call endpoint to change subscription
  597. rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
  598. "Authorization": util.BasicAuth("phil", "phil"),
  599. })
  600. require.Equal(t, 200, rr.Code)
  601. }
  602. func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
  603. stripeMock := &testStripeAPI{}
  604. defer stripeMock.AssertExpectations(t)
  605. c := newTestConfigWithAuthFile(t)
  606. c.StripeSecretKey = "secret key"
  607. c.StripeWebhookKey = "webhook key"
  608. s := newTestServer(t, c)
  609. s.stripe = stripeMock
  610. // Define how the mock should react
  611. stripeMock.
  612. On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool {
  613. return *s.CancelAtPeriodEnd // Is true
  614. })).
  615. Return(&stripe.Subscription{}, nil)
  616. // Create user
  617. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  618. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  619. StripeCustomerID: "acct_123",
  620. StripeSubscriptionID: "sub_123",
  621. }))
  622. // Delete subscription
  623. rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{
  624. "Authorization": util.BasicAuth("phil", "phil"),
  625. })
  626. require.Equal(t, 200, rr.Code)
  627. }
  628. func TestPayments_CreatePortalSession(t *testing.T) {
  629. stripeMock := &testStripeAPI{}
  630. defer stripeMock.AssertExpectations(t)
  631. c := newTestConfigWithAuthFile(t)
  632. c.StripeSecretKey = "secret key"
  633. c.StripeWebhookKey = "webhook key"
  634. s := newTestServer(t, c)
  635. s.stripe = stripeMock
  636. // Define how the mock should react
  637. stripeMock.
  638. On("NewPortalSession", &stripe.BillingPortalSessionParams{
  639. Customer: stripe.String("acct_123"),
  640. ReturnURL: stripe.String(s.config.BaseURL),
  641. }).
  642. Return(&stripe.BillingPortalSession{
  643. URL: "https://billing.stripe.com/blablabla",
  644. }, nil)
  645. // Create user
  646. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  647. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  648. StripeCustomerID: "acct_123",
  649. StripeSubscriptionID: "sub_123",
  650. }))
  651. // Create portal session
  652. rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{
  653. "Authorization": util.BasicAuth("phil", "phil"),
  654. })
  655. require.Equal(t, 200, rr.Code)
  656. ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body))
  657. require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL)
  658. }
  659. type testStripeAPI struct {
  660. mock.Mock
  661. }
  662. var _ stripeAPI = (*testStripeAPI)(nil)
  663. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  664. args := s.Called(params)
  665. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  666. }
  667. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  668. args := s.Called(params)
  669. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  670. }
  671. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  672. args := s.Called(params)
  673. return args.Get(0).([]*stripe.Price), args.Error(1)
  674. }
  675. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  676. args := s.Called(id)
  677. return args.Get(0).(*stripe.Customer), args.Error(1)
  678. }
  679. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  680. args := s.Called(id)
  681. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  682. }
  683. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  684. args := s.Called(id)
  685. return args.Get(0).(*stripe.Subscription), args.Error(1)
  686. }
  687. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  688. args := s.Called(id, params)
  689. return args.Get(0).(*stripe.Customer), args.Error(1)
  690. }
  691. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  692. args := s.Called(id, params)
  693. return args.Get(0).(*stripe.Subscription), args.Error(1)
  694. }
  695. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  696. args := s.Called(id)
  697. return args.Get(0).(*stripe.Subscription), args.Error(1)
  698. }
  699. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  700. args := s.Called(payload, header, secret)
  701. return args.Get(0).(stripe.Event), args.Error(1)
  702. }
  703. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  704. var e stripe.Event
  705. if err := json.Unmarshal([]byte(v), &e); err != nil {
  706. t.Fatal(err)
  707. }
  708. return e
  709. }
// subscriptionUpdatedEventJSON is a minimal "customer.subscription.updated" Stripe
// webhook event, decoded via jsonToStripeEvent in the subscription-update webhook test.
// It describes subscription sub_1234 of customer acct_5555 with status "active", a
// single item on price price_1234 with a yearly interval, and the current_period_end
// and cancel_at timestamps that the test expects to be stored on the user.
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234",
"recurring": {
"interval": "year"
}
}
}
]
}
}
}
}`
// subscriptionDeletedEventJSON is a minimal "customer.subscription.deleted" Stripe
// webhook event for subscription sub_1234 of customer acct_5555, decoded via
// jsonToStripeEvent in TestPayments_Webhook_Subscription_Deleted.
//
// NOTE(review): "status" is "active" here, although real deletion events presumably
// report a different status. That test expects the stored status to be cleared
// regardless, so its assertions do not depend on this value.
const subscriptionDeletedEventJSON = `
{
"type": "customer.subscription.deleted",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234",
"recurring": {
"interval": "month"
}
}
}
]
}
}
}
}`