// server_payments_test.go (20 KB)

  1. package server
  2. import (
  3. "encoding/json"
  4. "github.com/stretchr/testify/mock"
  5. "github.com/stretchr/testify/require"
  6. "github.com/stripe/stripe-go/v74"
  7. "golang.org/x/time/rate"
  8. "heckel.io/ntfy/user"
  9. "heckel.io/ntfy/util"
  10. "io"
  11. "net/netip"
  12. "path/filepath"
  13. "strings"
  14. "sync"
  15. "testing"
  16. "time"
  17. )
  18. func TestPayments_Tiers(t *testing.T) {
  19. stripeMock := &testStripeAPI{}
  20. defer stripeMock.AssertExpectations(t)
  21. c := newTestConfigWithAuthFile(t)
  22. c.StripeSecretKey = "secret key"
  23. c.StripeWebhookKey = "webhook key"
  24. c.VisitorRequestLimitReplenish = 12 * time.Hour
  25. c.CacheDuration = 13 * time.Hour
  26. c.AttachmentFileSizeLimit = 111
  27. c.VisitorAttachmentTotalSizeLimit = 222
  28. c.AttachmentExpiryDuration = 123 * time.Second
  29. s := newTestServer(t, c)
  30. s.stripe = stripeMock
  31. // Define how the mock should react
  32. stripeMock.
  33. On("ListPrices", mock.Anything).
  34. Return([]*stripe.Price{
  35. {ID: "price_123", UnitAmount: 500},
  36. {ID: "price_456", UnitAmount: 1000},
  37. {ID: "price_999", UnitAmount: 9999},
  38. }, nil)
  39. // Create tiers
  40. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  41. ID: "ti_1",
  42. Code: "admin",
  43. Name: "Admin",
  44. }))
  45. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  46. ID: "ti_123",
  47. Code: "pro",
  48. Name: "Pro",
  49. MessageLimit: 1000,
  50. MessageExpiryDuration: time.Hour,
  51. EmailLimit: 123,
  52. ReservationLimit: 777,
  53. AttachmentFileSizeLimit: 999,
  54. AttachmentTotalSizeLimit: 888,
  55. AttachmentExpiryDuration: time.Minute,
  56. StripePriceID: "price_123",
  57. }))
  58. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  59. ID: "ti_444",
  60. Code: "business",
  61. Name: "Business",
  62. MessageLimit: 2000,
  63. MessageExpiryDuration: 10 * time.Hour,
  64. EmailLimit: 123123,
  65. ReservationLimit: 777333,
  66. AttachmentFileSizeLimit: 999111,
  67. AttachmentTotalSizeLimit: 888111,
  68. AttachmentExpiryDuration: time.Hour,
  69. StripePriceID: "price_456",
  70. }))
  71. response := request(t, s, "GET", "/v1/tiers", "", nil)
  72. require.Equal(t, 200, response.Code)
  73. var tiers []apiAccountBillingTier
  74. require.Nil(t, json.NewDecoder(response.Body).Decode(&tiers))
  75. require.Equal(t, 3, len(tiers))
  76. // Free tier
  77. tier := tiers[0]
  78. require.Equal(t, "", tier.Code)
  79. require.Equal(t, "", tier.Name)
  80. require.Equal(t, "ip", tier.Limits.Basis)
  81. require.Equal(t, int64(0), tier.Limits.Reservations)
  82. require.Equal(t, int64(2), tier.Limits.Messages) // :-(
  83. require.Equal(t, int64(13*3600), tier.Limits.MessagesExpiryDuration)
  84. require.Equal(t, int64(24), tier.Limits.Emails)
  85. require.Equal(t, int64(111), tier.Limits.AttachmentFileSize)
  86. require.Equal(t, int64(222), tier.Limits.AttachmentTotalSize)
  87. require.Equal(t, int64(123), tier.Limits.AttachmentExpiryDuration)
  88. // Admin tier is not included, because it is not paid!
  89. tier = tiers[1]
  90. require.Equal(t, "pro", tier.Code)
  91. require.Equal(t, "Pro", tier.Name)
  92. require.Equal(t, "tier", tier.Limits.Basis)
  93. require.Equal(t, int64(777), tier.Limits.Reservations)
  94. require.Equal(t, int64(1000), tier.Limits.Messages)
  95. require.Equal(t, int64(3600), tier.Limits.MessagesExpiryDuration)
  96. require.Equal(t, int64(123), tier.Limits.Emails)
  97. require.Equal(t, int64(999), tier.Limits.AttachmentFileSize)
  98. require.Equal(t, int64(888), tier.Limits.AttachmentTotalSize)
  99. require.Equal(t, int64(60), tier.Limits.AttachmentExpiryDuration)
  100. tier = tiers[2]
  101. require.Equal(t, "business", tier.Code)
  102. require.Equal(t, "Business", tier.Name)
  103. require.Equal(t, "tier", tier.Limits.Basis)
  104. require.Equal(t, int64(777333), tier.Limits.Reservations)
  105. require.Equal(t, int64(2000), tier.Limits.Messages)
  106. require.Equal(t, int64(36000), tier.Limits.MessagesExpiryDuration)
  107. require.Equal(t, int64(123123), tier.Limits.Emails)
  108. require.Equal(t, int64(999111), tier.Limits.AttachmentFileSize)
  109. require.Equal(t, int64(888111), tier.Limits.AttachmentTotalSize)
  110. require.Equal(t, int64(3600), tier.Limits.AttachmentExpiryDuration)
  111. }
  112. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  113. stripeMock := &testStripeAPI{}
  114. defer stripeMock.AssertExpectations(t)
  115. c := newTestConfigWithAuthFile(t)
  116. c.StripeSecretKey = "secret key"
  117. c.StripeWebhookKey = "webhook key"
  118. s := newTestServer(t, c)
  119. s.stripe = stripeMock
  120. // Define how the mock should react
  121. stripeMock.
  122. On("NewCheckoutSession", mock.Anything).
  123. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  124. // Create tier and user
  125. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  126. ID: "ti_123",
  127. Code: "pro",
  128. StripePriceID: "price_123",
  129. }))
  130. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  131. // Create subscription
  132. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  133. "Authorization": util.BasicAuth("phil", "phil"),
  134. })
  135. require.Equal(t, 200, response.Code)
  136. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  137. require.Nil(t, err)
  138. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  139. }
  140. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  141. stripeMock := &testStripeAPI{}
  142. defer stripeMock.AssertExpectations(t)
  143. c := newTestConfigWithAuthFile(t)
  144. c.StripeSecretKey = "secret key"
  145. c.StripeWebhookKey = "webhook key"
  146. s := newTestServer(t, c)
  147. s.stripe = stripeMock
  148. // Define how the mock should react
  149. stripeMock.
  150. On("GetCustomer", "acct_123").
  151. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  152. stripeMock.
  153. On("NewCheckoutSession", mock.Anything).
  154. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  155. // Create tier and user
  156. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  157. ID: "ti_123",
  158. Code: "pro",
  159. StripePriceID: "price_123",
  160. }))
  161. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  162. u, err := s.userManager.User("phil")
  163. require.Nil(t, err)
  164. billing := &user.Billing{
  165. StripeCustomerID: "acct_123",
  166. }
  167. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  168. // Create subscription
  169. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  170. "Authorization": util.BasicAuth("phil", "phil"),
  171. })
  172. require.Equal(t, 200, response.Code)
  173. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  174. require.Nil(t, err)
  175. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  176. }
  177. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  178. stripeMock := &testStripeAPI{}
  179. defer stripeMock.AssertExpectations(t)
  180. c := newTestConfigWithAuthFile(t)
  181. c.EnableSignup = true
  182. c.StripeSecretKey = "secret key"
  183. c.StripeWebhookKey = "webhook key"
  184. s := newTestServer(t, c)
  185. s.stripe = stripeMock
  186. // Define how the mock should react
  187. stripeMock.
  188. On("CancelSubscription", "sub_123").
  189. Return(&stripe.Subscription{}, nil)
  190. // Create tier and user
  191. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  192. ID: "ti_123",
  193. Code: "pro",
  194. StripePriceID: "price_123",
  195. }))
  196. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  197. u, err := s.userManager.User("phil")
  198. require.Nil(t, err)
  199. billing := &user.Billing{
  200. StripeCustomerID: "acct_123",
  201. StripeSubscriptionID: "sub_123",
  202. }
  203. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  204. // Delete account
  205. rr := request(t, s, "DELETE", "/v1/account", `{"password": "phil"}`, map[string]string{
  206. "Authorization": util.BasicAuth("phil", "phil"),
  207. })
  208. require.Equal(t, 200, rr.Code)
  209. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  210. "Authorization": util.BasicAuth("phil", "mypass"),
  211. })
  212. require.Equal(t, 401, rr.Code)
  213. }
  214. func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *testing.T) {
  215. // This test is too overloaded, but it's also a great end-to-end a test.
  216. //
  217. // It tests:
  218. // - A successful checkout flow (not a paying customer -> paying customer)
  219. // - Tier-changes reset the rate limits for the user
  220. // - The request limits for tier-less user and a tier-user
  221. // - The message limits for a tier-user
  222. stripeMock := &testStripeAPI{}
  223. defer stripeMock.AssertExpectations(t)
  224. c := newTestConfigWithAuthFile(t)
  225. c.StripeSecretKey = "secret key"
  226. c.StripeWebhookKey = "webhook key"
  227. c.VisitorRequestLimitBurst = 5
  228. c.VisitorRequestLimitReplenish = time.Hour
  229. c.CacheStartupQueries = `
  230. pragma journal_mode = WAL;
  231. pragma synchronous = normal;
  232. pragma temp_store = memory;
  233. `
  234. c.CacheBatchSize = 500
  235. c.CacheBatchTimeout = time.Second
  236. s := newTestServer(t, c)
  237. s.stripe = stripeMock
  238. // Create a user with a Stripe subscription and 3 reservations
  239. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  240. ID: "ti_123",
  241. Code: "starter",
  242. StripePriceID: "price_1234",
  243. ReservationLimit: 1,
  244. MessageLimit: 220, // 220 * 5% = 11 requests before rate limiting kicks in
  245. MessageExpiryDuration: time.Hour,
  246. }))
  247. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) // No tier
  248. u, err := s.userManager.User("phil")
  249. require.Nil(t, err)
  250. // Define how the mock should react
  251. stripeMock.
  252. On("GetSession", "SOMETOKEN").
  253. Return(&stripe.CheckoutSession{
  254. ClientReferenceID: u.ID, // ntfy user ID
  255. Customer: &stripe.Customer{
  256. ID: "acct_5555",
  257. },
  258. Subscription: &stripe.Subscription{
  259. ID: "sub_1234",
  260. },
  261. }, nil)
  262. stripeMock.
  263. On("GetSubscription", "sub_1234").
  264. Return(&stripe.Subscription{
  265. ID: "sub_1234",
  266. Status: stripe.SubscriptionStatusActive,
  267. CurrentPeriodEnd: 123456789,
  268. CancelAt: 0,
  269. Items: &stripe.SubscriptionItemList{
  270. Data: []*stripe.SubscriptionItem{
  271. {
  272. Price: &stripe.Price{ID: "price_1234"},
  273. },
  274. },
  275. },
  276. }, nil)
  277. stripeMock.
  278. On("UpdateCustomer", mock.Anything).
  279. Return(&stripe.Customer{}, nil)
  280. // Send messages until rate limit of free tier is hit
  281. for i := 0; i < 5; i++ {
  282. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  283. "Authorization": util.BasicAuth("phil", "phil"),
  284. })
  285. require.Equal(t, 200, rr.Code)
  286. }
  287. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  288. "Authorization": util.BasicAuth("phil", "phil"),
  289. })
  290. require.Equal(t, 429, rr.Code)
  291. // Simulate Stripe success return URL call (no user context)
  292. rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
  293. require.Equal(t, 303, rr.Code)
  294. // Verify that database columns were updated
  295. u, err = s.userManager.User("phil")
  296. require.Nil(t, err)
  297. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  298. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  299. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  300. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
  301. require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
  302. require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
  303. // Now for the fun part: Verify that new rate limits are immediately applied
  304. // This only tests the request limiter, which kicks in before the message limiter.
  305. for i := 0; i < 11; i++ {
  306. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  307. "Authorization": util.BasicAuth("phil", "phil"),
  308. })
  309. require.Equal(t, 200, rr.Code, "failed on iteration %d", i)
  310. }
  311. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  312. "Authorization": util.BasicAuth("phil", "phil"),
  313. })
  314. require.Equal(t, 429, rr.Code)
  315. // Now let's test the message limiter by faking a ridiculously generous rate limiter
  316. v := s.visitor(netip.MustParseAddr("9.9.9.9"), u)
  317. v.requestLimiter = rate.NewLimiter(rate.Every(time.Millisecond), 1000000)
  318. var wg sync.WaitGroup
  319. for i := 0; i < 209; i++ {
  320. wg.Add(1)
  321. go func() {
  322. rr := request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  323. "Authorization": util.BasicAuth("phil", "phil"),
  324. })
  325. require.Equal(t, 200, rr.Code)
  326. wg.Done()
  327. }()
  328. }
  329. wg.Wait()
  330. rr = request(t, s, "PUT", "/mytopic", "some message", map[string]string{
  331. "Authorization": util.BasicAuth("phil", "phil"),
  332. })
  333. require.Equal(t, 429, rr.Code)
  334. // And now let's cross-check that the stats are correct too
  335. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  336. "Authorization": util.BasicAuth("phil", "phil"),
  337. })
  338. require.Equal(t, 200, rr.Code)
  339. account, _ := util.UnmarshalJSON[apiAccountResponse](io.NopCloser(rr.Body))
  340. require.Equal(t, int64(220), account.Limits.Messages)
  341. require.Equal(t, int64(220), account.Stats.Messages)
  342. require.Equal(t, int64(0), account.Stats.MessagesRemaining)
  343. }
  344. func TestPayments_Webhook_Subscription_Updated_Downgrade_From_PastDue_To_Active(t *testing.T) {
  345. // This tests incoming webhooks from Stripe to update a subscription:
  346. // - All Stripe columns are updated in the user table
  347. // - When downgrading, excess reservations are deleted, including messages and attachments in
  348. // the corresponding topics
  349. stripeMock := &testStripeAPI{}
  350. defer stripeMock.AssertExpectations(t)
  351. c := newTestConfigWithAuthFile(t)
  352. c.StripeSecretKey = "secret key"
  353. c.StripeWebhookKey = "webhook key"
  354. s := newTestServer(t, c)
  355. s.stripe = stripeMock
  356. // Define how the mock should react
  357. stripeMock.
  358. On("ConstructWebhookEvent", mock.Anything, "stripe signature", "webhook key").
  359. Return(jsonToStripeEvent(t, subscriptionUpdatedEventJSON), nil)
  360. // Create a user with a Stripe subscription and 3 reservations
  361. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  362. ID: "ti_1",
  363. Code: "starter",
  364. StripePriceID: "price_1234", // !
  365. ReservationLimit: 1, // !
  366. MessageLimit: 100,
  367. MessageExpiryDuration: time.Hour,
  368. AttachmentExpiryDuration: time.Hour,
  369. AttachmentFileSizeLimit: 1000000,
  370. AttachmentTotalSizeLimit: 1000000,
  371. AttachmentBandwidthLimit: 1000000,
  372. }))
  373. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  374. ID: "ti_2",
  375. Code: "pro",
  376. StripePriceID: "price_1111", // !
  377. ReservationLimit: 3, // !
  378. MessageLimit: 200,
  379. MessageExpiryDuration: time.Hour,
  380. AttachmentExpiryDuration: time.Hour,
  381. AttachmentFileSizeLimit: 1000000,
  382. AttachmentTotalSizeLimit: 1000000,
  383. AttachmentBandwidthLimit: 1000000,
  384. }))
  385. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
  386. require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
  387. require.Nil(t, s.userManager.AddReservation("phil", "atopic", user.PermissionDenyAll))
  388. require.Nil(t, s.userManager.AddReservation("phil", "ztopic", user.PermissionDenyAll))
  389. // Add billing details
  390. u, err := s.userManager.User("phil")
  391. require.Nil(t, err)
  392. billing := &user.Billing{
  393. StripeCustomerID: "acct_5555",
  394. StripeSubscriptionID: "sub_1234",
  395. StripeSubscriptionStatus: stripe.SubscriptionStatusPastDue,
  396. StripeSubscriptionPaidUntil: time.Unix(123, 0),
  397. StripeSubscriptionCancelAt: time.Unix(456, 0),
  398. }
  399. require.Nil(t, s.userManager.ChangeBilling(u.Name, billing))
  400. // Add some messages to "atopic" and "ztopic", everything in "ztopic" will be deleted
  401. rr := request(t, s, "PUT", "/atopic", "some aaa message", map[string]string{
  402. "Authorization": util.BasicAuth("phil", "phil"),
  403. })
  404. require.Equal(t, 200, rr.Code)
  405. rr = request(t, s, "PUT", "/atopic", strings.Repeat("a", 5000), map[string]string{
  406. "Authorization": util.BasicAuth("phil", "phil"),
  407. })
  408. require.Equal(t, 200, rr.Code)
  409. a2 := toMessage(t, rr.Body.String())
  410. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  411. rr = request(t, s, "PUT", "/ztopic", "some zzz message", map[string]string{
  412. "Authorization": util.BasicAuth("phil", "phil"),
  413. })
  414. require.Equal(t, 200, rr.Code)
  415. rr = request(t, s, "PUT", "/ztopic", strings.Repeat("z", 5000), map[string]string{
  416. "Authorization": util.BasicAuth("phil", "phil"),
  417. })
  418. require.Equal(t, 200, rr.Code)
  419. z2 := toMessage(t, rr.Body.String())
  420. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  421. // Call the webhook: This does all the magic
  422. rr = request(t, s, "POST", "/v1/account/billing/webhook", "dummy", map[string]string{
  423. "Stripe-Signature": "stripe signature",
  424. })
  425. require.Equal(t, 200, rr.Code)
  426. // Verify that database columns were updated
  427. u, err = s.userManager.User("phil")
  428. require.Nil(t, err)
  429. require.Equal(t, "starter", u.Tier.Code) // Not "pro"
  430. require.Equal(t, "acct_5555", u.Billing.StripeCustomerID)
  431. require.Equal(t, "sub_1234", u.Billing.StripeSubscriptionID)
  432. require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) // Not "past_due"
  433. require.Equal(t, int64(1674268231), u.Billing.StripeSubscriptionPaidUntil.Unix()) // Updated
  434. require.Equal(t, int64(1674299999), u.Billing.StripeSubscriptionCancelAt.Unix()) // Updated
  435. // Verify that reservations were deleted
  436. r, err := s.userManager.Reservations("phil")
  437. require.Nil(t, err)
  438. require.Equal(t, 1, len(r)) // "ztopic" reservation was deleted
  439. require.Equal(t, "atopic", r[0].Topic)
  440. // Verify that messages and attachments were deleted
  441. time.Sleep(time.Second)
  442. s.execManager()
  443. ms, err := s.messageCache.Messages("atopic", sinceAllMessages, false)
  444. require.Nil(t, err)
  445. require.Equal(t, 2, len(ms))
  446. require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, a2.ID))
  447. ms, err = s.messageCache.Messages("ztopic", sinceAllMessages, false)
  448. require.Nil(t, err)
  449. require.Equal(t, 0, len(ms))
  450. require.NoFileExists(t, filepath.Join(s.config.AttachmentCacheDir, z2.ID))
  451. }
  452. type testStripeAPI struct {
  453. mock.Mock
  454. }
  455. var _ stripeAPI = (*testStripeAPI)(nil)
  456. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  457. args := s.Called(params)
  458. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  459. }
  460. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  461. args := s.Called(params)
  462. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  463. }
  464. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  465. args := s.Called(params)
  466. return args.Get(0).([]*stripe.Price), args.Error(1)
  467. }
  468. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  469. args := s.Called(id)
  470. return args.Get(0).(*stripe.Customer), args.Error(1)
  471. }
  472. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  473. args := s.Called(id)
  474. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  475. }
  476. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  477. args := s.Called(id)
  478. return args.Get(0).(*stripe.Subscription), args.Error(1)
  479. }
  480. func (s *testStripeAPI) UpdateCustomer(id string, params *stripe.CustomerParams) (*stripe.Customer, error) {
  481. args := s.Called(id)
  482. return args.Get(0).(*stripe.Customer), args.Error(1)
  483. }
  484. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  485. args := s.Called(id)
  486. return args.Get(0).(*stripe.Subscription), args.Error(1)
  487. }
  488. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  489. args := s.Called(id)
  490. return args.Get(0).(*stripe.Subscription), args.Error(1)
  491. }
  492. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  493. args := s.Called(payload, header, secret)
  494. return args.Get(0).(stripe.Event), args.Error(1)
  495. }
  496. func jsonToStripeEvent(t *testing.T, v string) stripe.Event {
  497. var e stripe.Event
  498. if err := json.Unmarshal([]byte(v), &e); err != nil {
  499. t.Fatal(err)
  500. }
  501. return e
  502. }
// subscriptionUpdatedEventJSON is a minimal "customer.subscription.updated" Stripe event:
// subscription sub_1234 of customer acct_5555 is "active" on price price_1234, with
// current_period_end 1674268231 and cancel_at 1674299999 (Unix seconds).
// Decode it with jsonToStripeEvent to use it as a mocked webhook event.
const subscriptionUpdatedEventJSON = `
{
"type": "customer.subscription.updated",
"data": {
"object": {
"id": "sub_1234",
"customer": "acct_5555",
"status": "active",
"current_period_end": 1674268231,
"cancel_at": 1674299999,
"items": {
"data": [
{
"price": {
"id": "price_1234"
}
}
]
}
}
}
}`