server_payments_test.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. package server
  2. import (
  3. "github.com/stretchr/testify/mock"
  4. "github.com/stretchr/testify/require"
  5. "github.com/stripe/stripe-go/v74"
  6. "heckel.io/ntfy/user"
  7. "heckel.io/ntfy/util"
  8. "io"
  9. "testing"
  10. )
  11. func TestPayments_SubscriptionCreate_NotAStripeCustomer_Success(t *testing.T) {
  12. stripeMock := &testStripeAPI{}
  13. defer stripeMock.AssertExpectations(t)
  14. c := newTestConfigWithAuthFile(t)
  15. c.StripeSecretKey = "secret key"
  16. c.StripeWebhookKey = "webhook key"
  17. s := newTestServer(t, c)
  18. s.stripe = stripeMock
  19. // Define how the mock should react
  20. stripeMock.
  21. On("NewCheckoutSession", mock.Anything).
  22. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  23. // Create tier and user
  24. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  25. Code: "pro",
  26. StripePriceID: "price_123",
  27. }))
  28. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  29. // Create subscription
  30. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  31. "Authorization": util.BasicAuth("phil", "phil"),
  32. })
  33. require.Equal(t, 200, response.Code)
  34. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  35. require.Nil(t, err)
  36. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  37. }
  38. func TestPayments_SubscriptionCreate_StripeCustomer_Success(t *testing.T) {
  39. stripeMock := &testStripeAPI{}
  40. defer stripeMock.AssertExpectations(t)
  41. c := newTestConfigWithAuthFile(t)
  42. c.StripeSecretKey = "secret key"
  43. c.StripeWebhookKey = "webhook key"
  44. s := newTestServer(t, c)
  45. s.stripe = stripeMock
  46. // Define how the mock should react
  47. stripeMock.
  48. On("GetCustomer", "acct_123").
  49. Return(&stripe.Customer{Subscriptions: &stripe.SubscriptionList{}}, nil)
  50. stripeMock.
  51. On("NewCheckoutSession", mock.Anything).
  52. Return(&stripe.CheckoutSession{URL: "https://billing.stripe.com/abc/def"}, nil)
  53. // Create tier and user
  54. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  55. Code: "pro",
  56. StripePriceID: "price_123",
  57. }))
  58. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  59. u, err := s.userManager.User("phil")
  60. require.Nil(t, err)
  61. u.Billing.StripeCustomerID = "acct_123"
  62. require.Nil(t, s.userManager.ChangeBilling(u))
  63. // Create subscription
  64. response := request(t, s, "POST", "/v1/account/billing/subscription", `{"tier": "pro"}`, map[string]string{
  65. "Authorization": util.BasicAuth("phil", "phil"),
  66. })
  67. require.Equal(t, 200, response.Code)
  68. redirectResponse, err := util.UnmarshalJSON[apiAccountBillingSubscriptionCreateResponse](io.NopCloser(response.Body))
  69. require.Nil(t, err)
  70. require.Equal(t, "https://billing.stripe.com/abc/def", redirectResponse.RedirectURL)
  71. }
  72. func TestPayments_AccountDelete_Cancels_Subscription(t *testing.T) {
  73. stripeMock := &testStripeAPI{}
  74. defer stripeMock.AssertExpectations(t)
  75. c := newTestConfigWithAuthFile(t)
  76. c.EnableSignup = true
  77. c.StripeSecretKey = "secret key"
  78. c.StripeWebhookKey = "webhook key"
  79. s := newTestServer(t, c)
  80. s.stripe = stripeMock
  81. // Define how the mock should react
  82. stripeMock.
  83. On("CancelSubscription", "sub_123").
  84. Return(&stripe.Subscription{}, nil)
  85. // Create tier and user
  86. require.Nil(t, s.userManager.CreateTier(&user.Tier{
  87. Code: "pro",
  88. StripePriceID: "price_123",
  89. }))
  90. require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser, "unit-test"))
  91. u, err := s.userManager.User("phil")
  92. require.Nil(t, err)
  93. u.Billing.StripeCustomerID = "acct_123"
  94. u.Billing.StripeSubscriptionID = "sub_123"
  95. require.Nil(t, s.userManager.ChangeBilling(u))
  96. // Delete account
  97. rr := request(t, s, "DELETE", "/v1/account", "", map[string]string{
  98. "Authorization": util.BasicAuth("phil", "phil"),
  99. })
  100. require.Equal(t, 200, rr.Code)
  101. rr = request(t, s, "GET", "/v1/account", "", map[string]string{
  102. "Authorization": util.BasicAuth("phil", "mypass"),
  103. })
  104. require.Equal(t, 401, rr.Code)
  105. }
  106. type testStripeAPI struct {
  107. mock.Mock
  108. }
  109. func (s *testStripeAPI) NewCheckoutSession(params *stripe.CheckoutSessionParams) (*stripe.CheckoutSession, error) {
  110. args := s.Called(params)
  111. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  112. }
  113. func (s *testStripeAPI) NewPortalSession(params *stripe.BillingPortalSessionParams) (*stripe.BillingPortalSession, error) {
  114. args := s.Called(params)
  115. return args.Get(0).(*stripe.BillingPortalSession), args.Error(1)
  116. }
  117. func (s *testStripeAPI) ListPrices(params *stripe.PriceListParams) ([]*stripe.Price, error) {
  118. args := s.Called(params)
  119. return args.Get(0).([]*stripe.Price), args.Error(1)
  120. }
  121. func (s *testStripeAPI) GetCustomer(id string) (*stripe.Customer, error) {
  122. args := s.Called(id)
  123. return args.Get(0).(*stripe.Customer), args.Error(1)
  124. }
  125. func (s *testStripeAPI) GetSession(id string) (*stripe.CheckoutSession, error) {
  126. args := s.Called(id)
  127. return args.Get(0).(*stripe.CheckoutSession), args.Error(1)
  128. }
  129. func (s *testStripeAPI) GetSubscription(id string) (*stripe.Subscription, error) {
  130. args := s.Called(id)
  131. return args.Get(0).(*stripe.Subscription), args.Error(1)
  132. }
  133. func (s *testStripeAPI) UpdateSubscription(id string, params *stripe.SubscriptionParams) (*stripe.Subscription, error) {
  134. args := s.Called(id)
  135. return args.Get(0).(*stripe.Subscription), args.Error(1)
  136. }
  137. func (s *testStripeAPI) CancelSubscription(id string) (*stripe.Subscription, error) {
  138. args := s.Called(id)
  139. return args.Get(0).(*stripe.Subscription), args.Error(1)
  140. }
  141. func (s *testStripeAPI) ConstructWebhookEvent(payload []byte, header string, secret string) (stripe.Event, error) {
  142. args := s.Called(payload, header, secret)
  143. return args.Get(0).(stripe.Event), args.Error(1)
  144. }
  145. var _ stripeAPI = (*testStripeAPI)(nil)