server_payments_test.go 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858
  1. //go:build !nopayments
  2. package server
  3. import (
  4. "encoding/json"
  5. "github.com/stretchr/testify/mock"
  6. "github.com/stretchr/testify/require"
  7. "github.com/stripe/stripe-go/v74"
  8. "golang.org/x/time/rate"
  9. "heckel.io/ntfy/v2/user"
  10. "heckel.io/ntfy/v2/util"
  11. "io"
  12. "net/netip"
  13. "path/filepath"
  14. "strings"
  15. "sync"
  16. "testing"
  17. "time"
  18. )
  19. func TestPayments_Tiers(t *testing.T) {
  20. stripeMock := &testStripeAPI{}
  21. defer stripeMock.AssertExpectations(t)
  22. c := newTestConfigWithAuthFile(t)
  23. c.StripeSecretKey = "secret key"
  24. c.StripeWebhookKey = "webhook key"
  25. c.VisitorRequestLimitReplenish = 12 * time.Hour
  26. c.CacheDuration = 13 * time.Hour
  27. c.AttachmentFileSizeLimit = 111
  28. c.VisitorAttachmentTotalSizeLimit = 222
  29. c.AttachmentExpiryDuration = 123 * time.Second
  30. s := newTestServer(t, c)
  31. s.stripe = stripeMock
  32. // Define how the mock should react
  33. stripeMock.
  34. On("ListPrices", mock.Anything).
  35. Return([]*stripe.Price{
  36. {ID: "price_123", UnitAmount: 500},
  37. {ID: "price_124", UnitAmount: 5000},
  38. {ID: "price_456", UnitAmount: 1000},
  39. {ID: "price_457", UnitAmount: 10000},
  40. {ID: "price_999", UnitAmount: 9999},
  41. }, nil)
  42. // Create tiers
  43. require.Nil(t, s.userManager.AddTier(&user.Tier{
  44. ID: "ti_1",
  45. Code: "admin",
  46. Name: "Admin",
  47. }))
  48. require.Nil(t, s.userManager.AddTier(&user.Tier{
  49. ID: "ti_123",
  50. Code: "pro",
  51. Name: "Pro",
  52. MessageLimit: 1000,
  53. MessageExpiryDuration: time.Hour,
  54. EmailLimit: 123,
  55. ReservationLimit: 777,
  56. AttachmentFileSizeLimit: 999,
  57. AttachmentTotalSizeLimit: 888,
  58. AttachmentExpiryDuration: time.Minute,
  59. StripeMonthlyPriceID: "price_123",
  60. StripeYearlyPriceID: "price_124",
  61. }))
  62. require.Nil(t, s.userManager.AddTier(&user.Tier{
  63. ID: "ti_444",
  64. Code: "business",
  65. Name: "Business",
  66. MessageLimit: 2000,
  67. MessageExpiryDuration: 10 * time.Hour,
  68. EmailLimit: 123123,
  69. ReservationLimit: 777333,
  70. AttachmentFileSizeLimit: 999111,
  71. AttachmentTotalSizeLimit: 888111,
  72. AttachmentExpiryDuration: time.Hour,
  73. StripeMonthlyPriceID: "price_456",
  74. StripeYearlyPriceID: "price_457",
  75. }))
  76. response := request(t, s, "GET", "/v1/tiers", "", nil)
  77. require.Equal(t, 200, response.Code)
  78. var tiers []apiAccountBillingTier
  79. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  80. require.Equal(t, 3, len(tiers))
  81. // Free tier
  82. tier := tiers[0]
  83. require.Equal(t, "", tier.Code)
  84. require.Equal(t, "", tier.Name)
  85. require.Equal(t, "ip", tier.Limits.Basis)
  86. require.Equal(t, int64(0), tier.Limits.Reservations)
  87. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  88. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  89. require.Equal(t, int64(24), tier.Limits.Emails)
  90. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  91. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  92. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  93. // Admin tier is not included, because it is not paid!
  94. tier = tiers[1]
  95. require.Equal(t, "pro", tier.Code)
  96. require.Equal(t, "Pro", tier.Name)
  97. require.Equal(t, "tier", tier.Limits.Basis)
  98. require.Equal(t, int64(500), tier.Prices.Month)
  99. require.Equal(t, int64(5000), tier.Prices.Year)
  100. require.Equal(t, int64(777), tier.Limits.Reservations)
  101. require.Equal(t, int64(1000), tier.Limits.Messages)
  102. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  103. require.Equal(t, int64(123), tier.Limits.Emails)
  104. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  105. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  106. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  107. tier = tiers[2]
  108. require.Equal(t, "business", tier.Code)
  109. require.Equal(t, "Business", tier.Name)
  110. require.Equal(t, int64(1000), tier.Prices.Month)
  111. require.Equal(t, int64(10000), tier.Prices.Year)
  112. require.Equal(t, "tier", tier.Limits.Basis)
  113. require.Equal(t, int64(777333), tier.Limits.Reservations)
  114. require.Equal(t, int64(2000), tier.Limits.Messages)
  115. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  116. require.Equal(t, int64(123123), tier.Limits.Emails)
  117. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  118. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  119. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  120. }
  121. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  122. stripeMock := &testStripeAPI{}
  123. defer stripeMock.AssertExpectations(t)
  124. c := newTestConfigWithAuthFile(t)
  125. c.StripeSecretKey = "secret key"
  126. c.StripeWebhookKey = "webhook key"
  127. s := newTestServer(t, c)
  128. s.stripe = stripeMock
  129. // Define how the mock should react
  130. stripeMock.
  131. On("NewCheckoutSession", mock.Anything).
  132. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  133. // Create tier and user
  134. require.Nil(t, s.userManager.AddTier(&user.Tier{
  135. ID: "ti_123",
  136. Code: "pro",
  137. StripeMonthlyPriceID: "price_123",
  138. }))
  139. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  140. // Create subscription
  141. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  142. "Authorization": util.BasicAuth("phil", "phil"),
  143. })
  144. require.Equal(t, 200, response.Code)
  145. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  146. require.Nil(t, err)
  147. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  148. }
  149. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  150. stripeMock := &testStripeAPI{}
  151. defer stripeMock.AssertExpectations(t)
  152. c := newTestConfigWithAuthFile(t)
  153. c.StripeSecretKey = "secret key"
  154. c.StripeWebhookKey = "webhook key"
  155. s := newTestServer(t, c)
  156. s.stripe = stripeMock
  157. // Define how the mock should react
  158. stripeMock.
  159. On("GetCustomer", "acct_123").
  160. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  161. stripeMock.
  162. On("NewCheckoutSession", mock.Anything).
  163. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  164. // Create tier and user
  165. require.Nil(t, s.userManager.AddTier(&user.Tier{
  166. ID: "ti_123",
  167. Code: "pro",
  168. StripeMonthlyPriceID: "price_123",
  169. }))
  170. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  171. u, err := s.userManager.User("phil")
  172. require.Nil(t, err)
  173. billing := &user.Billing{
  174. StripeCustomerID: "acct_123",
  175. }
  176. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  177. // Create subscription
  178. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro", "interval": "month"}`, map[string]string{
  179. "Authorization": util.BasicAuth("phil", "phil"),
  180. })
  181. require.Equal(t, 200, response.Code)
  182. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  183. require.Nil(t, err)
  184. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  185. }
  186. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  187. stripeMock := &testStripeAPI{}
  188. defer stripeMock.AssertExpectations(t)
  189. c := newTestConfigWithAuthFile(t)
  190. c.EnableSignup = true
  191. c.StripeSecretKey = "secret key"
  192. c.StripeWebhookKey = "webhook key"
  193. s := newTestServer(t, c)
  194. s.stripe = stripeMock
  195. // Define how the mock should react
  196. stripeMock.
  197. On("CancelSubscription", "sub_123").
  198. Return(&stripe.Subscription{}, nil)
  199. // Create tier and user
  200. require.Nil(t, s.userManager.AddTier(&user.Tier{
  201. ID: "ti_123",
  202. Code: "pro",
  203. StripeMonthlyPriceID: "price_123",
  204. }))
  205. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  206. u, err := s.userManager.User("phil")
  207. require.Nil(t, err)
  208. billing := &user.Billing{
  209. StripeCustomerID: "acct_123",
  210. StripeSubscriptionID: "sub_123",
  211. }
  212. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  213. // Delete account
  214. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  215. "Authorization": util.BasicAuth("phil", "phil"),
  216. })
  217. require.Equal(t, 200, rr.Code)
  218. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  219. "Authorization": util.BasicAuth("phil", "mypass"),
  220. })
  221. require.Equal(t, 401, rr.Code)
  222. }
  223. func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
  224. // This test is too overloaded, but it's also a great end-to-end a test.
  225. //
  226. // It tests:
  227. // - A successful checkout flow (not a paying customer -> paying customer)
  228. // - Tier-changes reset the rate limits for the user
  229. // - The request limits for tier-less user and a tier-user
  230. // - The message limits for a tier-user
  231. stripeMock := &testStripeAPI{}
  232. defer stripeMock.AssertExpectations(t)
  233. c := newTestConfigWithAuthFile(t)
  234. c.StripeSecretKey = "secret key"
  235. c.StripeWebhookKey = "webhook key"
  236. c.VisitorRequestLimitBurst = 5
  237. c.VisitorRequestLimitReplenish = time.Hour
  238. c.CacheBatchSize = 500
  239. c.CacheBatchTimeout = time.Second
  240. s := newTestServer(t, c)
  241. s.stripe = stripeMock
  242. // Create a user with a Stripe subscription and 3 reservations
  243. require.Nil(t, s.userManager.AddTier(&user.Tier{
  244. ID: "ti_123",
  245. Code: "starter",
  246. StripeMonthlyPriceID: "price_1234",
  247. ReservationLimit: 1,
  248. MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
  249. MessageExpiryDuration: time.Hour,
  250. }))
  251. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false)) // No tier
  252. u, err := s.userManager.User("phil")
  253. require.Nil(t, err)
  254. // Define how the mock should react
  255. stripeMock.
  256. On("GetSession", "SOMETOKEN").
  257. Return(&stripe.CheckoutSession{
  258. ClientReferenceID: u.ID, // ntfy user ID
  259. Customer: &stripe.Customer{
  260. ID: "acct_5555",
  261. },
  262. Subscription: &stripe.Subscription{
  263. ID: "sub_1234",
  264. },
  265. }, nil)
  266. stripeMock.
  267. On("GetSubscription", "sub_1234").
  268. Return(&stripe.Subscription{
  269. ID: "sub_1234",
  270. Status: stripe.SubscriptionStatusActive,
  271. CurrentPeriodEnd: 123456789,
  272. CancelAt: 0,
  273. Items: &stripe.SubscriptionItemList{
  274. Data: []*stripe.SubscriptionItem{
  275. {
  276. Price: &stripe.Price{
  277. ID: "price_1234",
  278. Recurring: &stripe.PriceRecurring{
  279. Interval: stripe.PriceRecurringIntervalMonth,
  280. },
  281. },
  282. },
  283. },
  284. },
  285. }, nil)
  286. stripeMock.
  287. On("UpdateCustomer", "acct_5555", &stripe.CustomerParams{
  288. Params: stripe.Params{
  289. Metadata: map[string]string{
  290. "user_id": u.ID,
  291. "user_name": u.Name,
  292. },
  293. },
  294. }).
  295. Return(&stripe.Customer{}, nil)
  296. // Send messages until rate limit of free tier is hit
  297. for i := 0; i < 5; i++ {
  298. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  299. "Authorization": util.BasicAuth("phil", "phil"),
  300. })
  301. require.Equal(t, 200, rr.Code)
  302. }
  303. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  304. "Authorization": util.BasicAuth("phil", "phil"),
  305. })
  306. require.Equal(t, 429, rr.Code)
  307. // Verify some "before-stats"
  308. u, err = s.userManager.User("phil")
  309. require.Nil(t, err)
  310. require.Nil(t, u.Tier)
  311. require.Equal(t, "", u.Billing.StripeCustomerID)
  312. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  313. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  314. require.Equal(t, stripe.PriceRecurringInterval(""), u.Billing.StripeSubscriptionInterval)
  315. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  316. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  317. require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
  318. require.Equal(t, int64(0), u.Stats.Emails)
  319. // Simulate Stripe success return URL call (no user context)
  320. rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
  321. require.Equal(t, 303, rr.Code)
  322. // Verify that database columns were updated
  323. u, err = s.userManager.User("phil")
  324. require.Nil(t, err)
  325. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  326. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  327. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  328. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
  329. require.Equal(t, stripe.PriceRecurringIntervalMonth, u.Billing.StripeSubscriptionInterval)
  330. require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
  331. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  332. require.Equal(t, int64(0), u.Stats.Messages)
  333. require.Equal(t, int64(0), u.Stats.Emails)
  334. // Now for the fun part: Verify that new rate limits are immediately applied
  335. // This only tests the request limiter, which kicks in before the message limiter.
  336. for i := 0; i < 11; i++ {
  337. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  338. "Authorization": util.BasicAuth("phil", "phil"),
  339. })
  340. require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
  341. }
  342. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  343. "Authorization": util.BasicAuth("phil", "phil"),
  344. })
  345. require.Equal(t, 429, rr.Code)
  346. // Now let's test the message limiter by faking a ridiculously generous rate limiter
  347. v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
  348. v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
  349. var wg sync.WaitGroup
  350. for i := 0; i < 209; i++ {
  351. wg.Add(1)
  352. go func(i int) {
  353. defer wg.Done()
  354. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  355. "Authorization": util.BasicAuth("phil", "phil"),
  356. })
  357. require.Equal(t, 200, rr.Code, "Failed on %d", i)
  358. }(i)
  359. }
  360. wg.Wait()
  361. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  362. "Authorization": util.BasicAuth("phil", "phil"),
  363. })
  364. require.Equal(t, 429, rr.Code)
  365. // And now let's cross-check that the stats are correct too
  366. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  367. "Authorization": util.BasicAuth("phil", "phil"),
  368. })
  369. require.Equal(t, 200, rr.Code)
  370. account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
  371. require.Equal(t, int64(220), account.Limits.Messages)
  372. require.Equal(t, int64(220), account.Stats.Messages)
  373. require.Equal(t, int64(0), account.Stats.MessagesRemaining)
  374. }
// TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active feeds the
// subscriptionUpdatedEventJSON event into the billing webhook for a user on the "pro" tier.
// The event references "price_1234", the price of the "starter" tier, so the test expects the
// user to end up on "starter" with updated billing fields and only one remaining reservation.
func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
	t.Parallel()
	// This tests incoming webhooks from Stripe to update a subscription:
	// - All Stripe columns are updated in the user table
	// - When downgrading, excess reservations are deleted, including messages and attachments in
	// the corresponding topics
	stripeMock := &testStripeAPI{}
	defer stripeMock.AssertExpectations(t)
	c := newTestConfigWithAuthFile(t)
	c.StripeSecretKey = "secret key"
	c.StripeWebhookKey = "webhook key"
	s := newTestServer(t, c)
	s.stripe = stripeMock
	// Define how the mock should react: the webhook payload is ignored (mock.Anything), only the
	// signature header and webhook key must match; the parsed event is always the "updated" fixture
	stripeMock.
		On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
		Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
	// Create two tiers, and a user on the "pro" tier with a Stripe subscription and 2 reservations
	require.Nil(t, s.userManager.AddTier(&user.Tier{
		ID:                       "ti_1",
		Code:                     "starter",
		StripeMonthlyPriceID:     "price_1234", // !
		ReservationLimit:         1,            // !
		MessageLimit:             100,
		MessageExpiryDuration:    time.Hour,
		AttachmentExpiryDuration: time.Hour,
		AttachmentFileSizeLimit:  1000000,
		AttachmentTotalSizeLimit: 1000000,
		AttachmentBandwidthLimit: 1000000,
	}))
	require.Nil(t, s.userManager.AddTier(&user.Tier{
		ID:                       "ti_2",
		Code:                     "pro",
		StripeMonthlyPriceID:     "price_1111", // !
		ReservationLimit:         3,            // !
		MessageLimit:             200,
		MessageExpiryDuration:    time.Hour,
		AttachmentExpiryDuration: time.Hour,
		AttachmentFileSizeLimit:  1000000,
		AttachmentTotalSizeLimit: 1000000,
		AttachmentBandwidthLimit: 1000000,
	}))
	require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
	require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
	require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
	require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
	// Add billing details; status, interval, paid-until and cancel-at all differ from the event's values
	u, err := s.userManager.User("phil")
	require.Nil(t, err)
	billing := &user.Billing{
		StripeCustomerID:            "acct_5555",
		StripeSubscriptionID:        "sub_1234",
		StripeSubscriptionStatus:    stripe.SubscriptionStatusPastDue,
		StripeSubscriptionInterval:  stripe.PriceRecurringIntervalMonth,
		StripeSubscriptionPaidUntil: time.Unix(123, 0),
		StripeSubscriptionCancelAt:  time.Unix(456, 0),
	}
	require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
	// Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted.
	// The 5000-character bodies are expected to end up as files in the attachment cache dir.
	rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	a2 := toMessage(t, rr.Body.String())
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
	rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
		"Authorization": util.BasicAuth("phil", "phil"),
	})
	require.Equal(t, 200, rr.Code)
	z2 := toMessage(t, rr.Body.String())
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
	// Call the webhook: This does all the magic
	rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
		"Stripe-Signature": "stripe signature",
	})
	require.Equal(t, 200, rr.Code)
	// Verify that database columns were updated
	u, err = s.userManager.User("phil")
	require.Nil(t, err)
	require.Equal(t, "starter", u.Tier.Code) // Not "pro"
	require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
	require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
	require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)     // Not "past_due"
	require.Equal(t, stripe.PriceRecurringIntervalYear, u.Billing.StripeSubscriptionInterval) // Not "month"
	require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix())         // Updated
	require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix())          // Updated
	// Verify that reservations were deleted
	r, err := s.userManager.Reservations("phil")
	require.Nil(t, err)
	require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
	require.Equal(t, "atopic", r[0].Topic)
	// Verify that messages and attachments were deleted
	// NOTE(review): the sleep presumably gives asynchronous deletion time to finish before the
	// manager run below removes the attachment files — confirm against the server implementation.
	time.Sleep(time.Second)
	s.execManager()
	ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
	require.Nil(t, err)
	require.Equal(t, 2, len(ms))
	require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
	ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
	require.Nil(t, err)
	require.Equal(t, 0, len(ms))
	require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
}
  486. func TestPayments_Webhook_Subscription_Deleted(t *testing.T) {
  487. // This tests incoming webhooks from Stripe to delete a subscription. It verifies that the database is
  488. // updated (all Stripe fields are deleted, and the tier is removed).
  489. //
  490. // It doesn't fully test the message/attachment deletion. That is tested above in the subscription update call.
  491. stripeMock := &testStripeAPI{}
  492. defer stripeMock.AssertExpectations(t)
  493. c := newTestConfigWithAuthFile(t)
  494. c.StripeSecretKey = "secret key"
  495. c.StripeWebhookKey = "webhook key"
  496. s := newTestServer(t, c)
  497. s.stripe = stripeMock
  498. // Define how the mock should react
  499. stripeMock.
  500. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  501. Return(jsonToStripeEvent(t, subscriptionDeletedEventJSON), nil)
  502. // Create a user with a Stripe subscription and 3 reservations
  503. require.Nil(t, s.userManager.AddTier(&user.Tier{
  504. ID: "ti_1",
  505. Code: "pro",
  506. StripeMonthlyPriceID: "price_1234",
  507. ReservationLimit: 1,
  508. }))
  509. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  510. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  511. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  512. // Add billing details
  513. u, err := s.userManager.User("phil")
  514. require.Nil(t, err)
  515. require.Nil(t, s.userManager.ChangeBilling(u.Name, &user.Billing{
  516. StripeCustomerID: "acct_5555",
  517. StripeSubscriptionID: "sub_1234",
  518. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  519. StripeSubscriptionInterval: stripe.PriceRecurringIntervalMonth,
  520. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  521. StripeSubscriptionCancelAt: time.Unix(0, 0),
  522. }))
  523. // Call the webhook: This does all the magic
  524. rr := request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  525. "Stripe-Signature": "stripe signature",
  526. })
  527. require.Equal(t, 200, rr.Code)
  528. // Verify that database columns were updated
  529. u, err = s.userManager.User("phil")
  530. require.Nil(t, err)
  531. require.Nil(t, u.Tier)
  532. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  533. require.Equal(t, "", u.Billing.StripeSubscriptionID)
  534. require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
  535. require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
  536. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  537. // Verify that reservations were deleted
  538. r, err := s.userManager.Reservations("phil")
  539. require.Nil(t, err)
  540. require.Equal(t, 0, len(r))
  541. }
  542. func TestPayments_Subscription_Update_Different_Tier(t *testing.T) {
  543. stripeMock := &testStripeAPI{}
  544. defer stripeMock.AssertExpectations(t)
  545. c := newTestConfigWithAuthFile(t)
  546. c.StripeSecretKey = "secret key"
  547. c.StripeWebhookKey = "webhook key"
  548. s := newTestServer(t, c)
  549. s.stripe = stripeMock
  550. // Define how the mock should react
  551. stripeMock.
  552. On("GetSubscription", "sub_123").
  553. Return(&stripe.Subscription{
  554. ID: "sub_123",
  555. Items: &stripe.SubscriptionItemList{
  556. Data: []*stripe.SubscriptionItem{
  557. {
  558. ID: "someid_123",
  559. Price: &stripe.Price{ID: "price_123"},
  560. },
  561. },
  562. },
  563. }, nil)
  564. stripeMock.
  565. On("UpdateSubscription", "sub_123", &stripe.SubscriptionParams{
  566. CancelAtPeriodEnd: stripe.Bool(false),
  567. ProrationBehavior: stripe.String(string(stripe.SubscriptionSchedulePhaseProrationBehaviorAlwaysInvoice)),
  568. Items: []*stripe.SubscriptionItemsParams{
  569. {
  570. ID: stripe.String("someid_123"),
  571. Price: stripe.String("price_457"),
  572. },
  573. },
  574. }).
  575. Return(&stripe.Subscription{}, nil)
  576. // Create tier and user
  577. require.Nil(t, s.userManager.AddTier(&user.Tier{
  578. ID: "ti_123",
  579. Code: "pro",
  580. StripeMonthlyPriceID: "price_123",
  581. StripeYearlyPriceID: "price_124",
  582. }))
  583. require.Nil(t, s.userManager.AddTier(&user.Tier{
  584. ID: "ti_456",
  585. Code: "business",
  586. StripeMonthlyPriceID: "price_456",
  587. StripeYearlyPriceID: "price_457",
  588. }))
  589. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  590. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  591. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  592. StripeCustomerID: "acct_123",
  593. StripeSubscriptionID: "sub_123",
  594. }))
  595. // Call endpoint to change subscription
  596. rr := request(t, s, "PUT", "/v1/account/billing/subscription", `{"tier":"business","interval":"year"}`, map[string]string{
  597. "Authorization": util.BasicAuth("phil", "phil"),
  598. })
  599. require.Equal(t, 200, rr.Code)
  600. }
  601. func TestPayments_Subscription_Delete_At_Period_End(t *testing.T) {
  602. stripeMock := &testStripeAPI{}
  603. defer stripeMock.AssertExpectations(t)
  604. c := newTestConfigWithAuthFile(t)
  605. c.StripeSecretKey = "secret key"
  606. c.StripeWebhookKey = "webhook key"
  607. s := newTestServer(t, c)
  608. s.stripe = stripeMock
  609. // Define how the mock should react
  610. stripeMock.
  611. On("UpdateSubscription", "sub_123", mock.MatchedBy(func(s *stripe.SubscriptionParams) bool {
  612. return *s.CancelAtPeriodEnd // Is true
  613. })).
  614. Return(&stripe.Subscription{}, nil)
  615. // Create user
  616. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  617. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  618. StripeCustomerID: "acct_123",
  619. StripeSubscriptionID: "sub_123",
  620. }))
  621. // Delete subscription
  622. rr := request(t, s, "DELETE", "/v1/account/billing/subscription", "", map[string]string{
  623. "Authorization": util.BasicAuth("phil", "phil"),
  624. })
  625. require.Equal(t, 200, rr.Code)
  626. }
  627. func TestPayments_CreatePortalSession(t *testing.T) {
  628. stripeMock := &testStripeAPI{}
  629. defer stripeMock.AssertExpectations(t)
  630. c := newTestConfigWithAuthFile(t)
  631. c.StripeSecretKey = "secret key"
  632. c.StripeWebhookKey = "webhook key"
  633. s := newTestServer(t, c)
  634. s.stripe = stripeMock
  635. // Define how the mock should react
  636. stripeMock.
  637. On("NewPortalSession", &stripe.BillingPortalSessionParams{
  638. Customer: stripe.String("acct_123"),
  639. ReturnURL: stripe.String(s.config.BaseURL),
  640. }).
  641. Return(&stripe.BillingPortalSession{
  642. URL: "https://billing.stripe.com/blablabla",
  643. }, nil)
  644. // Create user
  645. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, false))
  646. require.Nil(t, s.userManager.ChangeBilling("phil", &user.Billing{
  647. StripeCustomerID: "acct_123",
  648. StripeSubscriptionID: "sub_123",
  649. }))
  650. // Create portal session
  651. rr := request(t, s, "POST", "/v1/account/billing/portal", "", map[string]string{
  652. "Authorization": util.BasicAuth("phil", "phil"),
  653. })
  654. require.Equal(t, 200, rr.Code)
  655. ps, _ := util.UnmarshalJSON[apiAccountBillingPortalRedirectResponse](io.NopCloser(rr.Body))
  656. require.Equal(t, "https://billing.stripe.com/blablabla", ps.RedirectURL)
  657. }
  658. type testStripeAPI struct {
  659. mock.Mock
  660. }
  661. var _ stripeAPI = (*testStripeAPI)(nil)
  662. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  663. args := s.Called(params)
  664. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  665. }
  666. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  667. args := s.Called(params)
  668. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  669. }
  670. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  671. args := s.Called(params)
  672. return args.Get(0).([]*stripe.Price), args.Error(1)
  673. }
  674. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  675. args := s.Called(id)
  676. return args.Get(0).(*stripe.Customer), args.Error(1)
  677. }
  678. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  679. args := s.Called(id)
  680. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  681. }
  682. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  683. args := s.Called(id)
  684. return args.Get(0).(*stripe.Subscription), args.Error(1)
  685. }
  686. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  687. args := s.Called(id, params)
  688. return args.Get(0).(*stripe.Customer), args.Error(1)
  689. }
  690. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  691. args := s.Called(id, params)
  692. return args.Get(0).(*stripe.Subscription), args.Error(1)
  693. }
  694. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  695. args := s.Called(id)
  696. return args.Get(0).(*stripe.Subscription), args.Error(1)
  697. }
  698. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  699. args := s.Called(payload, header, secret)
  700. return args.Get(0).(stripe.Event), args.Error(1)
  701. }
  702. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  703. var e stripe.Event
  704. if err := json.Unmarshal([]byte(v), &e); err != nil {
  705. t.Fatal(err)
  706. }
  707. return e
  708. }
// subscriptionUpdatedEventJSON is a "customer.subscription.updated" event fixture for
// subscription "sub_1234" of customer "acct_5555": status "active", with a single item on
// price "price_1234" that recurs yearly.
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234",
"recurring": {
"interval": "year"
}
}
}
]
}
}
}
}`
// subscriptionDeletedEventJSON is a "customer.subscription.deleted" event fixture for
// subscription "sub_1234" of customer "acct_5555", with a single item on price "price_1234"
// that recurs monthly.
const subscriptionDeletedEventJSON = `
{
"type": "customer.subscription.deleted",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234",
"recurring": {
"interval": "month"
}
}
}
]
}
}
}
}`